From 9c752ff0d3a9e98c550e44083af114cc28a3b907 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 11:52:17 +0200 Subject: [PATCH 01/24] Trying IQ3_S without a lookup table --- ggml-cuda.cu | 185 ++++++++++++++---------- ggml-quants.c | 391 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 384 insertions(+), 192 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dfd28df62..a58214557 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2008,72 +2008,72 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const __device__ uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, -}; +//static const __device__ uint32_t iq3xs_grid[512] = { +//0x04040404, 0x04142c14, 0x042c2424, 0x0404143c, 0x04140c0c, 0x042c0424, 0x04043434, 0x041c240c, +//0x04341c1c, 0x040c0c34, 0x041c0404, 0x0434341c, 0x040c2c2c, 0x04241c04, 0x043c1414, 0x0414042c, +//0x04243c04, 0x04042c14, 0x04142424, 0x042c143c, 0x04040c0c, 0x0c1c0424, 0x0c2c3434, 0x0c04240c, +//0x0c1c141c, 0x0c340c34, 0x0c0c0404, 0x0c24341c, 0x0c34242c, 0x0c0c1c04, 0x0c240c14, 0x0c3c042c, +//0x0c143404, 0x0c2c2c14, 0x14041c24, 0x1414143c, 0x142c040c, 0x14043c24, 0x141c2c34, 0x1434240c, +//0x140c141c, 0x141c0c34, 0x14340404, 0x140c341c, 0x1424242c, 0x143c1c04, 0x14140c14, 0x1424042c, +//0x1c043404, 0x1c142414, 0x1c2c1c24, 0x1c040c3c, 0x1c1c040c, 0x1c2c3424, 0x1c042c34, 0x1c1c1c0c, +//0x1c34141c, 0x1c0c0434, 0x1c243c04, 0x1c342c1c, 0x1c0c242c, 0x1c241404, 0x243c0c14, 0x2414042c, +//0x242c3404, 0x24042414, 0x24141c24, 0x242c0c3c, 0x2404040c, 0x241c3424, 0x24342434, 0x24041c0c, +//0x241c0c1c, 0x24340434, 0x240c3404, 0x2c242c1c, 0x2c3c1c2c, 0x2c141404, 0x2c240414, 0x2c043c2c, +//0x2c142c04, 0x2c2c2414, 0x2c041424, 0x2c1c0c3c, 0x2c2c040c, 0x2c043424, 0x2c1c2434, 0x2c341c0c, +//0x2c0c0c1c, 0x34240434, 0x34343404, 0x340c2c1c, 0x34241c2c, 0x343c1404, 0x34140414, 0x342c342c, +//0x34042c04, 0x34141c14, 0x342c1424, 0x3404043c, 0x341c3c0c, 0x34342c24, 0x3c042434, 0x3c1c140c, +//0x3c340c1c, 0x3c0c0434, 0x3c243404, 0x3c3c241c, 0x3c0c1c2c, 0x04240c04, 0x04040414, 0x0414342c, +//0x042c2c04, 0x04041c14, 0x041c1424, 0x042c043c, 0x04043c0c, 0x041c2c24, 0x04341c34, 0x040c140c, +//0x0424041c, 0x04343c34, 0x040c2c04, 0x0424241c, 0x043c142c, 0x04140c04, 0x042c0414, 0x0404342c, +//0x04142404, 0x042c1c14, 0x0c040c24, 0x0c1c043c, 0x0c34340c, 0x0c042c24, 0x0c1c1c34, 0x0c34140c, +//0x0c0c041c, 0x0c243c34, 0x0c3c2c04, 0x0c0c241c, 0x0c24142c, 0x0c040404, 0x0c143c14, 0x142c2c2c, +//0x14042404, 0x14141414, 0x142c0c24, 0x1404043c, 0x141c340c, 0x14342424, 0x140c1c34, 0x14240c0c, +//0x1434041c, 0x140c3434, 0x14242c04, 0x143c1c1c, 0x1414142c, 0x1c2c0404, 0x1c043c14, 0x1c142c2c, +//0x1c2c2404, 0x1c041414, 0x1c1c0c24, 0x1c343c3c, 0x1c042c0c, 0x1c1c2424, 0x1c341434, 0x1c0c0c0c, +//0x1c24041c, 0x1c3c3434, 0x240c2404, 0x24241c1c, 0x24040c2c, 0x24140404, 0x242c3414, 0x24042c2c, +//0x24141c04, 0x242c1414, 0x24040424, 0x241c3c3c, 0x24342c0c, 0x240c2424, 0x241c1434, 0x24340c0c, +//0x2c0c041c, 0x2c243434, 0x2c3c2404, 0x2c14141c, 0x2c2c0c2c, 0x2c040404, 0x2c143414, 0x2c2c242c, +//0x2c041c04, 0x2c1c0c14, 0x2c340424, 0x2c04343c, 0x2c1c2c0c, 0x2c341c24, 0x340c1434, 0x3424040c, +//0x343c3c1c, 0x340c2c34, 0x34242404, 0x3404141c, 0x34140c2c, 0x342c0404, 0x34043414, 0x3414242c, +//0x342c1c04, 0x34040c14, 0x341c0424, 0x3c34343c, 0x3c0c240c, 0x3c1c1c24, 0x3c340c34, 0x3c0c040c, +//0x3c24341c, 0x3c3c2c34, 0x04141c04, 0x0424141c, 0x0404042c, 0x04143c04, 0x042c2c14, 0x0404242c, +//0x041c1404, 0x04340c14, 0x04040424, 0x041c343c, 0x0434240c, 0x040c1c24, 0x04240c34, 0x043c040c, +//0x040c341c, 0x04242434, 0x04041c04, 0x04140c1c, 0x042c042c, 0x04043404, 0x0c142c14, 0x0c2c1c2c, +//0x0c041404, 0x0c1c0414, 0x0c343c24, 0x0c0c2c3c, 0x0c1c240c, 0x0c341424, 0x0c0c0c34, 0x0c24040c, +//0x0c3c341c, 0x0c142434, 0x0c241c04, 0x0c040c1c, 0x1414042c, 0x142c3404, 0x14042c14, 0x141c1c2c, +//0x142c1404, 0x14040414, 0x141c3424, 0x14342c3c, 0x140c1c0c, 0x14241424, 0x143c0434, 0x140c3c0c, +//0x14242c1c, 0x1c042434, 0x1c141404, 0x1c2c0c1c, 0x1c04042c, 0x1c143404, 0x1c2c2414, 0x1c041c2c, +//0x1c1c0c04, 0x1c340414, 0x1c0c3424, 0x1c1c2c3c, 0x1c341c0c, 0x1c0c1424, 0x1c240434, 0x243c3c0c, +//0x24142c1c, 0x24241c34, 0x24041404, 0x2414041c, 0x242c3c2c, 0x24042c04, 0x241c2414, 0x242c142c, +//0x24040c04, 0x241c0414, 0x24343424, 0x240c243c, 0x24241c0c, 0x2c340c24, 0x2c0c0434, 0x2c24340c, +//0x2c3c2c1c, 0x2c141c34, 0x2c2c1404, 0x2c04041c, 0x2c143c2c, 0x2c2c2c04, 0x2c042414, 0x2c1c142c, +//0x2c340404, 0x2c0c3c14, 0x341c2c24, 0x3434243c, 0x340c140c, 0x34240c24, 0x343c0434, 0x3414340c, +//0x3424241c, 0x34041c34, 0x34140c04, 0x342c041c, 0x3404342c, 0x341c2c04, 0x342c1c14, 0x3404142c, +//0x3c1c0404, 0x3c343c14, 0x3c0c2c24, 0x3c24243c, 0x3c34140c, 0x3c0c0c24, 0x3c243c34, 0x043c2c0c, +//0x0414241c, 0x042c1434, 0x04040c04, 0x0414041c, 0x042c342c, 0x04042404, 0x041c1c14, 0x04340c2c, +//0x040c0404, 0x041c3414, 0x04342c24, 0x040c1c3c, 0x0424140c, 0x043c0424, 0x04143c34, 0x04242c0c, +//0x0404241c, 0x04141434, 0x042c0c04, 0x0c04041c, 0x0c1c342c, 0x0c2c2404, 0x0c041414, 0x0c1c0c2c, +//0x0c340404, 0x0c0c3414, 0x0c242424, 0x0c341c3c, 0x0c0c0c0c, 0x0c240424, 0x0c3c3434, 0x0c142c0c, +//0x0c2c1c1c, 0x14041434, 0x14140404, 0x142c3c1c, 0x14042c2c, 0x141c2404, 0x14341414, 0x14040c2c, +//0x141c0404, 0x14343414, 0x140c2424, 0x14241c3c, 0x143c0c0c, 0x14140424, 0x1c243434, 0x1c04240c, +//0x1c141c1c, 0x1c2c0c34, 0x1c040404, 0x1c1c341c, 0x1c2c2c2c, 0x1c041c04, 0x1c1c1414, 0x1c34042c, +//0x1c0c3c04, 0x1c242c14, 0x1c342424, 0x1c0c143c, 0x24240c0c, 0x243c0424, 0x24143434, 0x242c240c, +//0x24041c1c, 0x24140c34, 0x242c0404, 0x2404341c, 0x241c242c, 0x24341c04, 0x24040c14, 0x241c042c, +//0x24343404, 0x2c0c2c14, 0x2c241c24, 0x2c3c143c, 0x2c0c040c, 0x2c243c24, 0x2c042c34, 0x2c14240c, +//0x2c2c141c, 0x2c040c34, 0x2c1c0404, 0x2c2c341c, 0x2c04242c, 0x2c1c1c04, 0x2c340c14, 0x340c042c, +//0x34243404, 0x34342c14, 0x340c1c24, 0x34240c3c, 0x343c040c, 0x34143424, 0x342c2c34, 0x34041c0c, +//0x3414141c, 0x342c0434, 0x34043c04, 0x341c2c1c, 0x3434242c, 0x3c041404, 0x3c1c0c14, 0x3c34042c, +//0x3c0c3404, 0x3c242414, 0x3c3c1c24, 0x040c0c3c, 0x0424040c, 0x04043424, 0x04142c34, 0x042c1c0c, +//0x0404141c, 0x04140434, 0x042c3c04, 0x04042c1c, 0x041c1c2c, 0x04341404, 0x040c0414, 0x041c3c2c, +//0x04342c04, 0x040c2414, 0x04241424, 0x043c0c3c, 0x0414040c, 0x042c3424, 0x04042434, 0x04141c0c, +//0x0c2c0c1c, 0x0c040434, 0x0c1c3404, 0x0c342c1c, 0x0c041c2c, 0x0c1c1404, 0x0c340414, 0x0c0c3c2c, +//0x0c242c04, 0x0c3c2414, 0x0c0c1424, 0x0c24043c, 0x0c043c0c, 0x14142c24, 0x142c2434, 0x1404140c, +//0x14140c1c, 0x142c0434, 0x14043404, 0x141c241c, 0x14341c2c, 0x140c0c04, 0x141c0414, 0x1434342c, +//0x140c2c04, 0x14241c14, 0x143c1424, 0x1c14043c, 0x1c243c0c, 0x1c042c24, 0x1c142434, 0x1c2c140c, +//0x1c040c1c, 0x1c1c3c34, 0x1c342c04, 0x1c04241c, 0x1c1c142c, 0x1c340c04, 0x1c0c0414, 0x1c24342c, +//0x1c3c2404, 0x240c1c14, 0x24240c24, 0x2404043c, 0x2414340c, 0x242c2c24, 0x24041c34, 0x2414140c, +//0x242c041c, 0x24043c34, 0x241c2c04, 0x2434241c, 0x240c142c, 0x241c0c04, 0x2c340414, 0x2c0c342c, +//}; static const __device__ uint64_t iq1s_grid[512] = { @@ -2370,6 +2370,12 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } +//#define IQ3S_MULTIPLIER 2469109 +//#define IQ3S_MULTIPLIER 746226 +//#define IQ3S_MULTIPLIER 717154 +#define IQ3S_MULTIPLIER 677595 +//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15}; + template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2382,14 +2388,25 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); - const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; + int32_t aux32[2]; + const uint8_t * grid = (const uint8_t *)aux32; + const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; - for (int j = 0; j < 4; ++j) { - y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + //y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + //y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } +// const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); +// const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); +// const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; +// const uint8_t signs = x[i].signs[4*ib + il]; +// for (int j = 0; j < 4; ++j) { +// y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); +// y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); +// } #else assert(false); #endif @@ -5196,22 +5213,36 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( #if QK_K == 256 const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + uint32_t aux32[2]; + const uint8_t * grid = (const uint8_t *)aux32; + const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; const int8_t * q8 = bq8_1[ib32].qs; int sumi = 0; for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); + const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); + const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); + //const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + //const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + //uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + //uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + //const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); + //const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); + //sumi = __dp4a(grid_l, *((int *)q8+0), sumi); + //sumi = __dp4a(grid_h, *((int *)q8+1), sumi); q8 += 8; } - const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; + //const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; + const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; #else assert(false); diff --git a/ggml-quants.c b/ggml-quants.c index f73d17ce2..9f9d299fe 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3790,70 +3790,70 @@ static const uint32_t iq3xxs_grid[256] = { }; static const uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +0x04040404, 0x04142c14, 0x042c2424, 0x0404143c, 0x04140c0c, 0x042c0424, 0x04043434, 0x041c240c, +0x04341c1c, 0x040c0c34, 0x041c0404, 0x0434341c, 0x040c2c2c, 0x04241c04, 0x043c1414, 0x0414042c, +0x04243c04, 0x04042c14, 0x04142424, 0x042c143c, 0x04040c0c, 0x0c1c0424, 0x0c2c3434, 0x0c04240c, +0x0c1c141c, 0x0c340c34, 0x0c0c0404, 0x0c24341c, 0x0c34242c, 0x0c0c1c04, 0x0c240c14, 0x0c3c042c, +0x0c143404, 0x0c2c2c14, 0x14041c24, 0x1414143c, 0x142c040c, 0x14043c24, 0x141c2c34, 0x1434240c, +0x140c141c, 0x141c0c34, 0x14340404, 0x140c341c, 0x1424242c, 0x143c1c04, 0x14140c14, 0x1424042c, +0x1c043404, 0x1c142414, 0x1c2c1c24, 0x1c040c3c, 0x1c1c040c, 0x1c2c3424, 0x1c042c34, 0x1c1c1c0c, +0x1c34141c, 0x1c0c0434, 0x1c243c04, 0x1c342c1c, 0x1c0c242c, 0x1c241404, 0x243c0c14, 0x2414042c, +0x242c3404, 0x24042414, 0x24141c24, 0x242c0c3c, 0x2404040c, 0x241c3424, 0x24342434, 0x24041c0c, +0x241c0c1c, 0x24340434, 0x240c3404, 0x2c242c1c, 0x2c3c1c2c, 0x2c141404, 0x2c240414, 0x2c043c2c, +0x2c142c04, 0x2c2c2414, 0x2c041424, 0x2c1c0c3c, 0x2c2c040c, 0x2c043424, 0x2c1c2434, 0x2c341c0c, +0x2c0c0c1c, 0x34240434, 0x34343404, 0x340c2c1c, 0x34241c2c, 0x343c1404, 0x34140414, 0x342c342c, +0x34042c04, 0x34141c14, 0x342c1424, 0x3404043c, 0x341c3c0c, 0x34342c24, 0x3c042434, 0x3c1c140c, +0x3c340c1c, 0x3c0c0434, 0x3c243404, 0x3c3c241c, 0x3c0c1c2c, 0x04240c04, 0x04040414, 0x0414342c, +0x042c2c04, 0x04041c14, 0x041c1424, 0x042c043c, 0x04043c0c, 0x041c2c24, 0x04341c34, 0x040c140c, +0x0424041c, 0x04343c34, 0x040c2c04, 0x0424241c, 0x043c142c, 0x04140c04, 0x042c0414, 0x0404342c, +0x04142404, 0x042c1c14, 0x0c040c24, 0x0c1c043c, 0x0c34340c, 0x0c042c24, 0x0c1c1c34, 0x0c34140c, +0x0c0c041c, 0x0c243c34, 0x0c3c2c04, 0x0c0c241c, 0x0c24142c, 0x0c040404, 0x0c143c14, 0x142c2c2c, +0x14042404, 0x14141414, 0x142c0c24, 0x1404043c, 0x141c340c, 0x14342424, 0x140c1c34, 0x14240c0c, +0x1434041c, 0x140c3434, 0x14242c04, 0x143c1c1c, 0x1414142c, 0x1c2c0404, 0x1c043c14, 0x1c142c2c, +0x1c2c2404, 0x1c041414, 0x1c1c0c24, 0x1c343c3c, 0x1c042c0c, 0x1c1c2424, 0x1c341434, 0x1c0c0c0c, +0x1c24041c, 0x1c3c3434, 0x240c2404, 0x24241c1c, 0x24040c2c, 0x24140404, 0x242c3414, 0x24042c2c, +0x24141c04, 0x242c1414, 0x24040424, 0x241c3c3c, 0x24342c0c, 0x240c2424, 0x241c1434, 0x24340c0c, +0x2c0c041c, 0x2c243434, 0x2c3c2404, 0x2c14141c, 0x2c2c0c2c, 0x2c040404, 0x2c143414, 0x2c2c242c, +0x2c041c04, 0x2c1c0c14, 0x2c340424, 0x2c04343c, 0x2c1c2c0c, 0x2c341c24, 0x340c1434, 0x3424040c, +0x343c3c1c, 0x340c2c34, 0x34242404, 0x3404141c, 0x34140c2c, 0x342c0404, 0x34043414, 0x3414242c, +0x342c1c04, 0x34040c14, 0x341c0424, 0x3c34343c, 0x3c0c240c, 0x3c1c1c24, 0x3c340c34, 0x3c0c040c, +0x3c24341c, 0x3c3c2c34, 0x04141c04, 0x0424141c, 0x0404042c, 0x04143c04, 0x042c2c14, 0x0404242c, +0x041c1404, 0x04340c14, 0x04040424, 0x041c343c, 0x0434240c, 0x040c1c24, 0x04240c34, 0x043c040c, +0x040c341c, 0x04242434, 0x04041c04, 0x04140c1c, 0x042c042c, 0x04043404, 0x0c142c14, 0x0c2c1c2c, +0x0c041404, 0x0c1c0414, 0x0c343c24, 0x0c0c2c3c, 0x0c1c240c, 0x0c341424, 0x0c0c0c34, 0x0c24040c, +0x0c3c341c, 0x0c142434, 0x0c241c04, 0x0c040c1c, 0x1414042c, 0x142c3404, 0x14042c14, 0x141c1c2c, +0x142c1404, 0x14040414, 0x141c3424, 0x14342c3c, 0x140c1c0c, 0x14241424, 0x143c0434, 0x140c3c0c, +0x14242c1c, 0x1c042434, 0x1c141404, 0x1c2c0c1c, 0x1c04042c, 0x1c143404, 0x1c2c2414, 0x1c041c2c, +0x1c1c0c04, 0x1c340414, 0x1c0c3424, 0x1c1c2c3c, 0x1c341c0c, 0x1c0c1424, 0x1c240434, 0x243c3c0c, +0x24142c1c, 0x24241c34, 0x24041404, 0x2414041c, 0x242c3c2c, 0x24042c04, 0x241c2414, 0x242c142c, +0x24040c04, 0x241c0414, 0x24343424, 0x240c243c, 0x24241c0c, 0x2c340c24, 0x2c0c0434, 0x2c24340c, +0x2c3c2c1c, 0x2c141c34, 0x2c2c1404, 0x2c04041c, 0x2c143c2c, 0x2c2c2c04, 0x2c042414, 0x2c1c142c, +0x2c340404, 0x2c0c3c14, 0x341c2c24, 0x3434243c, 0x340c140c, 0x34240c24, 0x343c0434, 0x3414340c, +0x3424241c, 0x34041c34, 0x34140c04, 0x342c041c, 0x3404342c, 0x341c2c04, 0x342c1c14, 0x3404142c, +0x3c1c0404, 0x3c343c14, 0x3c0c2c24, 0x3c24243c, 0x3c34140c, 0x3c0c0c24, 0x3c243c34, 0x043c2c0c, +0x0414241c, 0x042c1434, 0x04040c04, 0x0414041c, 0x042c342c, 0x04042404, 0x041c1c14, 0x04340c2c, +0x040c0404, 0x041c3414, 0x04342c24, 0x040c1c3c, 0x0424140c, 0x043c0424, 0x04143c34, 0x04242c0c, +0x0404241c, 0x04141434, 0x042c0c04, 0x0c04041c, 0x0c1c342c, 0x0c2c2404, 0x0c041414, 0x0c1c0c2c, +0x0c340404, 0x0c0c3414, 0x0c242424, 0x0c341c3c, 0x0c0c0c0c, 0x0c240424, 0x0c3c3434, 0x0c142c0c, +0x0c2c1c1c, 0x14041434, 0x14140404, 0x142c3c1c, 0x14042c2c, 0x141c2404, 0x14341414, 0x14040c2c, +0x141c0404, 0x14343414, 0x140c2424, 0x14241c3c, 0x143c0c0c, 0x14140424, 0x1c243434, 0x1c04240c, +0x1c141c1c, 0x1c2c0c34, 0x1c040404, 0x1c1c341c, 0x1c2c2c2c, 0x1c041c04, 0x1c1c1414, 0x1c34042c, +0x1c0c3c04, 0x1c242c14, 0x1c342424, 0x1c0c143c, 0x24240c0c, 0x243c0424, 0x24143434, 0x242c240c, +0x24041c1c, 0x24140c34, 0x242c0404, 0x2404341c, 0x241c242c, 0x24341c04, 0x24040c14, 0x241c042c, +0x24343404, 0x2c0c2c14, 0x2c241c24, 0x2c3c143c, 0x2c0c040c, 0x2c243c24, 0x2c042c34, 0x2c14240c, +0x2c2c141c, 0x2c040c34, 0x2c1c0404, 0x2c2c341c, 0x2c04242c, 0x2c1c1c04, 0x2c340c14, 0x340c042c, +0x34243404, 0x34342c14, 0x340c1c24, 0x34240c3c, 0x343c040c, 0x34143424, 0x342c2c34, 0x34041c0c, +0x3414141c, 0x342c0434, 0x34043c04, 0x341c2c1c, 0x3434242c, 0x3c041404, 0x3c1c0c14, 0x3c34042c, +0x3c0c3404, 0x3c242414, 0x3c3c1c24, 0x040c0c3c, 0x0424040c, 0x04043424, 0x04142c34, 0x042c1c0c, +0x0404141c, 0x04140434, 0x042c3c04, 0x04042c1c, 0x041c1c2c, 0x04341404, 0x040c0414, 0x041c3c2c, +0x04342c04, 0x040c2414, 0x04241424, 0x043c0c3c, 0x0414040c, 0x042c3424, 0x04042434, 0x04141c0c, +0x0c2c0c1c, 0x0c040434, 0x0c1c3404, 0x0c342c1c, 0x0c041c2c, 0x0c1c1404, 0x0c340414, 0x0c0c3c2c, +0x0c242c04, 0x0c3c2414, 0x0c0c1424, 0x0c24043c, 0x0c043c0c, 0x14142c24, 0x142c2434, 0x1404140c, +0x14140c1c, 0x142c0434, 0x14043404, 0x141c241c, 0x14341c2c, 0x140c0c04, 0x141c0414, 0x1434342c, +0x140c2c04, 0x14241c14, 0x143c1424, 0x1c14043c, 0x1c243c0c, 0x1c042c24, 0x1c142434, 0x1c2c140c, +0x1c040c1c, 0x1c1c3c34, 0x1c342c04, 0x1c04241c, 0x1c1c142c, 0x1c340c04, 0x1c0c0414, 0x1c24342c, +0x1c3c2404, 0x240c1c14, 0x24240c24, 0x2404043c, 0x2414340c, 0x242c2c24, 0x24041c34, 0x2414140c, +0x242c041c, 0x24043c34, 0x241c2c04, 0x2434241c, 0x240c142c, 0x241c0c04, 0x2c340414, 0x2c0c342c, }; #define NGRID_IQ2XXS 512 @@ -4121,10 +4121,19 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization +//#define IQ3S_MULTIPLIER 2469109 +//#define IQ3S_MULTIPLIER 746226 +//#define IQ3S_MULTIPLIER 717154 +#define IQ3S_MULTIPLIER 677595 +#define IQ3S_BITS 3 + void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; + uint32_t aux32[2]; + const int8_t * grid = (const int8_t *)aux32; + for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); @@ -4133,25 +4142,25 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const uint8_t * signs = x[i].signs; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; - const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + //y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db1 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + //y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db2 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -4159,6 +4168,34 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in qs += 8; signs += 4; } + + //for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + // const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; + // const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; + // for (int l = 0; l < 4; ++l) { + // const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + // const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + // for (int j = 0; j < 4; ++j) { + // y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + // y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + // } + // y += 8; + // } + // qs += 8; + // signs += 4; + // for (int l = 0; l < 4; ++l) { + // const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + // const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + // for (int j = 0; j < 4; ++j) { + // y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + // y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + // } + // y += 8; + // } + // qh += 2; + // qs += 8; + // signs += 4; + //} } } @@ -11311,6 +11348,112 @@ static int iq3_compare_func(const void * left, const void * right) { return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; } +static void iq3xs_init_grid512(void) { + + const int kmap_size = 1 << IQ3S_BITS*4; + const int grid_size = 512; + const int nwant = 3; + const int gindex = iq3_data_index(512); + const uint8_t kmask = (1 << IQ3S_BITS) - 1; + + uint32_t * kgrid_q3xs; + int * kmap_q3xs; + uint16_t * kneighbors_q3xs; + + printf("================================================================= %s(grid_size = %d, map_size = %d)\n", + __func__, grid_size, kmap_size); + uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t)); + uint32_t aux32; + const uint8_t * q4 = (const uint8_t *)&aux32; + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + aux32 = (IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; + for (int i = 0; i < 4; ++i) { + //pos[i] = 2*((q4[i]-1)/2) + 1; + pos[i] = 2*(((q4[i]+1)/2) & kmask) + 1; + } + } + + kgrid_q3xs = the_grid; + iq3_data[gindex].grid = the_grid; + kmap_q3xs = (int *)malloc(kmap_size*sizeof(int)); + iq3_data[gindex].map = kmap_q3xs; + for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1; + for (int i = 0; i < grid_size; ++i) { + aux32 = kgrid_q3xs[i]; + uint16_t index = 0; + for (int k=0; k<4; ++k) { + uint16_t q = (q4[k] - 1)/2; + index |= (q << IQ3S_BITS*k); + } + kmap_q3xs[index] = i; + } + int8_t pos[4]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> IQ3S_BITS*k) & kmask; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq3_data[gindex].neighbours = kneighbors_q3xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> IQ3S_BITS*k) & kmask; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + kmap_q3xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} + void iq3xs_init_impl(int grid_size) { const int gindex = iq3_data_index(grid_size); if (iq3_data[gindex].grid) { @@ -11334,44 +11477,49 @@ void iq3xs_init_impl(int grid_size) { 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, }; - static const uint16_t kgrid_512[512] = { - 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, - 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, - 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, - 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, - 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, - 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, - 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, - 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, - 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, - 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, - 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, - 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, - 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, - 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, - 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, - 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, - 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, - 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, - 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, - 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, - 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, - 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, - 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, - 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, - 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, - 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, - 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, - 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, - 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, - 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, - 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, - 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, - }; + if (grid_size == 512) { + iq3xs_init_grid512(); + return; + } +// static const uint16_t kgrid_512[512] = { +// 0, 170, 356, 23, 137, 324, 54, 225, 411, 78, 192, 435, 109, 280, 466, 133, +// 312, 42, 164, 343, 9, 708, 886, 545, 723, 910, 576, 819, 933, 600, 778, 965, +// 688, 874, 1052, 1175, 1345, 1084, 1262, 1441, 1107, 1230, 1408, 1139, 1317, 1496, 1162, 1285, +//1584, 1698, 1884, 1551, 1729, 1908, 1582, 1753, 1939, 1606, 1848, 1963, 1637, 1808, 2506, 2181, +//2416, 2082, 2204, 2383, 2049, 2292, 2470, 2073, 2251, 2438, 2160, 2859, 3037, 2704, 2818, 2621, +//2728, 2914, 2580, 2767, 2881, 2612, 2790, 2969, 2635, 3334, 3504, 3179, 3357, 3536, 3202, 3445, +//3112, 3226, 3412, 3079, 3321, 3500, 3622, 3793, 3979, 3654, 3888, 4067, 3677, 264, 2, 181, +// 360, 26, 212, 327, 57, 236, 414, 81, 259, 446, 104, 291, 469, 136, 322, 53, +// 160, 346, 524, 711, 945, 556, 734, 913, 579, 830, 1000, 611, 789, 512, 698, 1389, +//1056, 1170, 1356, 1031, 1265, 1444, 1118, 1289, 1411, 1142, 1320, 1499, 1173, 1856, 1594, 1709, +//1888, 1554, 1740, 1983, 1577, 1764, 1942, 1609, 1795, 2038, 2144, 2331, 2061, 2176, 2418, 2093, +//2200, 2386, 2052, 2303, 2473, 2148, 2262, 2441, 2627, 2870, 3040, 2707, 2893, 2560, 2738, 2917, +//2584, 2762, 2948, 2615, 2793, 2972, 3158, 3329, 3579, 3182, 3360, 3091, 3213, 3392, 3122, 3237, +//3416, 3082, 3268, 4023, 3681, 3804, 3982, 3649, 3891, 4078, 152, 275, 5, 184, 362, 37, +// 208, 394, 4, 247, 417, 92, 270, 449, 115, 294, 24, 139, 325, 48, 682, 861, +// 528, 706, 956, 623, 737, 916, 590, 769, 1011, 678, 792, 523, 1157, 1392, 1066, 1245, +//1360, 1026, 1268, 1455, 1113, 1300, 1478, 1145, 1323, 1574, 1680, 1867, 1541, 1712, 1890, 1565, +//1736, 1922, 1652, 1775, 1945, 1620, 1798, 2553, 2219, 2334, 2064, 2179, 2429, 2088, 2274, 2389, +//2056, 2242, 2484, 2151, 2329, 2956, 2630, 2865, 3051, 2718, 2896, 2563, 2749, 2920, 2594, 2773, +//2944, 2682, 3308, 3495, 3153, 3340, 3526, 3249, 3363, 3102, 3208, 3395, 3125, 3304, 3418, 3093, +//3776, 4026, 3692, 3879, 3985, 3660, 3902, 489, 163, 342, 8, 131, 373, 32, 218, 397, +// 64, 242, 428, 95, 273, 452, 190, 297, 35, 150, 328, 515, 757, 864, 530, 717, +// 896, 626, 804, 927, 585, 772, 1014, 681, 859, 1046, 1152, 1403, 1069, 1248, 1426, 1037, +//1216, 1458, 1124, 1311, 1481, 1156, 1846, 1569, 1691, 1870, 1536, 1779, 1901, 1560, 1746, 1925, +//1656, 1834, 1956, 1623, 2313, 2500, 2230, 2401, 2075, 2190, 2368, 2099, 2277, 2456, 2058, 2245, +//2480, 2666, 2844, 3031, 2625, 2876, 2606, 2721, 2899, 2574, 2752, 2931, 2597, 2776, 2954, 3141, +//3376, 3498, 3164, 3343, 3521, 3252, 3438, 3097, 3219, 3398, 3128, 3307, 3493, 3600, 3786, 3973, +//3696, 3874, 4060, 79, 257, 52, 174, 345, 19, 134, 376, 43, 221, 400, 66, 253, +// 424, 98, 276, 463, 129, 372, 38, 153, 843, 518, 752, 939, 541, 720, 898, 637, +// 808, 994, 596, 775, 569, 1196, 1382, 1041, 1163, 1350, 1072, 1251, 1437, 1096, 1218, 1461, +//1128, 1306, 1492, 1671, 1849, 1580, 1702, 1873, 1547, 1790, 1960, 1571, 1749, 1928, 1602, 1845, +//2016, 2138, 2316, 2055, 2225, 2412, 2078, 2193, 2371, 2110, 2280, 2467, 2133, 2248, 2946, 2677, +// }; const int kmap_size = 4096; const int nwant = grid_size == 256 ? 2 : 3; - const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + //const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + const uint16_t * kgrid = kgrid_256; uint32_t * kgrid_q3xs; int * kmap_q3xs; uint16_t * kneighbors_q3xs; @@ -11773,6 +11921,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo for (int ibl = 0; ibl < nbl; ++ibl) { + //float block_mse = 0; + memset(&y[ibl], 0, sizeof(block_iq3_s)); y[ibl].d = GGML_FP32_TO_FP16(0.f); @@ -11823,7 +11973,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); } uint16_t u = 0; - for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + for (int i = 0; i < 4; ++i) u |= Laux[4*k+i] << IQ3S_BITS*i; int grid_index = kmap_q3xs[u]; is_on_grid_aux[k] = true; if (grid_index < 0) { @@ -11855,7 +12005,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo for (int i = 0; i < 4; ++i) { int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); l = MAX(0, MIN(kMaxQ-1, l)); - u |= (l << 3*i); + u |= l << IQ3S_BITS*i; } int grid_index = kmap_q3xs[u]; if (grid_index < 0) { @@ -11882,7 +12032,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } for (int k = 0; k < bs4; ++k) { uint16_t u = 0; - for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + for (int i = 0; i < 4; ++i) u |= L[4*k+i] << IQ3S_BITS*i; int grid_index = kmap_q3xs[u]; if (grid_index < 0) { printf("Oops: found point %u not on grid:", u); @@ -11899,14 +12049,25 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo GGML_ASSERT(scale >= 0); scales[ib] = scale; max_scale = MAX(max_scale, scale); + + //for (int k = 0; k < bs8; ++k) { + // for (int i = 0; i < 8; ++i) { + // float diff = scale*(2*L[8*k+i] + 1) * (block_signs[k] & (1 << i) ? -1 : 1) - xb[8*k+i]; + // block_mse += diff*diff; + // } + //} + } + //printf("Block %d: rmse = %g\n", ibl, (double)sqrtf(block_mse/QK_K)); + if (!max_scale) { continue; } float d = max_scale/31; - y[ibl].d = GGML_FP32_TO_FP16(d); + //y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); //1.02f); //1.0125f); + y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); //1.04f); //1.02f); //1.0125f); float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); From 1cc7cb2b4639b1eca4a91a25acd88c193db71ec7 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 12:02:39 +0200 Subject: [PATCH 02/24] iq3_s(multiplier): use SIMD also in dequantize --- ggml-cuda.cu | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a58214557..4f88a1ae1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2389,24 +2389,28 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; int32_t aux32[2]; - const uint8_t * grid = (const uint8_t *)aux32; + const int8_t * grid = (const int8_t *)aux32; const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); + aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); + aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); + for (int j = 0; j < 8; ++j) { + y[j] = d * grid[j]; + } +#else for (int j = 0; j < 8; ++j) { //y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); //y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } -// const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); -// const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); -// const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; -// const uint8_t signs = x[i].signs[4*ib + il]; -// for (int j = 0; j < 4; ++j) { -// y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); -// y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); -// } +#endif #else assert(false); #endif From 4c21c826e1a6489c27080eddf8c82a9d997e607e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 13:28:20 +0200 Subject: [PATCH 03/24] WIP --- ggml-cuda.cu | 10 ++++------ ggml-quants.c | 46 +++++++++++++++++++++++++++------------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4f88a1ae1..328fc01ed 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2370,11 +2370,9 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -//#define IQ3S_MULTIPLIER 2469109 //#define IQ3S_MULTIPLIER 746226 //#define IQ3S_MULTIPLIER 717154 #define IQ3S_MULTIPLIER 677595 -//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15}; template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2395,8 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); - aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); @@ -5227,8 +5225,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( for (int l = 0; l < 4; ++l) { aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); - aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); diff --git a/ggml-quants.c b/ggml-quants.c index 9f9d299fe..5e98291bc 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10019,6 +10019,15 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +#ifdef __AVX2__ +static inline __m256i shift_left_epi16(__m256i a, __m256i count) { + const __m256i mask = _mm256_set1_epi32(0xffff0000); + const __m256i lo_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); + const __m256i hi_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); + return _mm256_blend_epi16(lo_half, hi_half, 0xaa); +} +#endif + void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -10109,6 +10118,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m256i idx_mask = _mm256_set1_epi16(256); + const __m256i idx_shift = _mm256_set_epi16(8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1); + const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); + const __m256i m1 = _mm256_set1_epi32(0x01010101); + const __m256i m7 = _mm256_set1_epi32(0x07070707); + const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -10121,24 +10137,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]); - qs += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]); - qs += 8; + const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8); + const __m256i idx_h_16 = _mm256_set_m128i(_mm_set1_epi16(qh[ib32+1]), _mm_set1_epi16(qh[ib32+0])); + const __m256i idx_16 = _mm256_or_si256(idx_l_16, _mm256_and_si256(shift_left_epi16(idx_h_16, idx_shift), idx_mask)); + const __m256i idx_32_l = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_16)); + const __m256i idx_32_h = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_16, 1)); + const __m256i idx_l = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); + const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); + const __m256i idx_h = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -10166,7 +10174,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } - *s = 0.25f * hsum_float_8(accumf); + *s = hsum_float_8(accumf); #else From 160acecabab3fcca48e34ec838d82fb44269a7a2 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 13:44:06 +0200 Subject: [PATCH 04/24] iq3_s_multiplier: CUDA and AVX2 works CUDA is 153.8 t/s, so faster than lookup table (151 t/s) and Q3_K_S (145 t/s). AVX2 on Ryzen-5975WX is 13.7 t/s, so faster than lookup (12.7 t/s), but slower than Q3_K_S (15.5 t/s). --- ggml-quants.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-quants.c b/ggml-quants.c index 5e98291bc..ac851f9ab 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10119,7 +10119,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); const __m256i idx_mask = _mm256_set1_epi16(256); - const __m256i idx_shift = _mm256_set_epi16(8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1); + const __m256i idx_shift = _mm256_set_epi16(1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); const __m256i m1 = _mm256_set1_epi32(0x01010101); const __m256i m7 = _mm256_set1_epi32(0x07070707); From e43e81a5d7908b9b369ffa0f4b59bfd90d094796 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 18:48:08 +0200 Subject: [PATCH 05/24] WIP --- ggml-quants.c | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index ac851f9ab..78facff12 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10118,8 +10118,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - const __m256i idx_mask = _mm256_set1_epi16(256); - const __m256i idx_shift = _mm256_set_epi16(1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); const __m256i m1 = _mm256_set1_epi32(0x01010101); const __m256i m7 = _mm256_set1_epi32(0x07070707); @@ -10139,10 +10139,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16; const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8); - const __m256i idx_h_16 = _mm256_set_m128i(_mm_set1_epi16(qh[ib32+1]), _mm_set1_epi16(qh[ib32+0])); - const __m256i idx_16 = _mm256_or_si256(idx_l_16, _mm256_and_si256(shift_left_epi16(idx_h_16, idx_shift), idx_mask)); - const __m256i idx_32_l = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_16)); - const __m256i idx_32_h = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_16, 1)); + const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask); + const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask); + const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16))); + const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); const __m256i idx_l = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); const __m256i idx_h = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); From 0fe9cd488f752dc4fe14238b181089b486724127 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 2 Mar 2024 17:56:16 +0200 Subject: [PATCH 06/24] WIP --- ggml-cuda.cu | 26 +++++++------------------- ggml-quants.c | 25 ++++++++++++++----------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 328fc01ed..6dec7f576 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2371,8 +2371,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } //#define IQ3S_MULTIPLIER 746226 -//#define IQ3S_MULTIPLIER 717154 -#define IQ3S_MULTIPLIER 677595 +//#define IQ3S_MULTIPLIER 677595 +#define IQ3S_MULTIPLIER 190842953LL template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2393,8 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); @@ -2404,9 +2404,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ } #else for (int j = 0; j < 8; ++j) { - //y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - //y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } #endif #else @@ -5216,7 +5214,6 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const block_iq3_s * bq2 = (const block_iq3_s *) vbq; uint32_t aux32[2]; - const uint8_t * grid = (const uint8_t *)aux32; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -5225,25 +5222,16 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( for (int l = 0; l < 4; ++l) { aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); - //const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - //const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); - //uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - //uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - //const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); - //const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); - //sumi = __dp4a(grid_l, *((int *)q8+0), sumi); - //sumi = __dp4a(grid_h, *((int *)q8+1), sumi); q8 += 8; } - //const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; #else diff --git a/ggml-quants.c b/ggml-quants.c index 78facff12..cf60c65cf 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4124,7 +4124,8 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y //#define IQ3S_MULTIPLIER 2469109 //#define IQ3S_MULTIPLIER 746226 //#define IQ3S_MULTIPLIER 717154 -#define IQ3S_MULTIPLIER 677595 +//#define IQ3S_MULTIPLIER 677595 +#define IQ3S_MULTIPLIER 190842953 #define IQ3S_BITS 3 void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { @@ -4148,8 +4149,7 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; for (int j = 0; j < 8; ++j) { - //y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - y[j] = db1 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -4159,8 +4159,7 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; for (int j = 0; j < 8; ++j) { - //y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - y[j] = db2 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -10121,9 +10120,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_mask = _mm256_set1_epi32(256); const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); - const __m256i m1 = _mm256_set1_epi32(0x01010101); + const __m256i m1 = _mm256_set1_epi8(1); const __m256i m7 = _mm256_set1_epi32(0x07070707); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + const __m256i m0 = _mm256_setzero_si256(); __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -10143,9 +10143,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask); const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16))); const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); - const __m256i idx_l = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); + + // v = MAX(((IQ3S_MULTIPLIER * idx) & 0x0f0f0f0f) - 1, 0) + const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); + // v = (((v >> 1) & 0x07070707) << 1) | 0x01010101 const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - const __m256i idx_h = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + + const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); @@ -11375,10 +11379,9 @@ static void iq3xs_init_grid512(void) { const uint8_t * q4 = (const uint8_t *)&aux32; for (int k = 0; k < grid_size; ++k) { int8_t * pos = (int8_t *)(the_grid + k); - aux32 = (IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; + aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - //pos[i] = 2*((q4[i]-1)/2) + 1; - pos[i] = 2*(((q4[i]+1)/2) & kmask) + 1; + pos[i] = 2*((q4[i]-1)/2) + 1; } } From bf90920fb2f4f9315201ffee09523df8a14ab977 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 2 Mar 2024 19:17:27 +0200 Subject: [PATCH 07/24] iq3_s_mult: ARM_NEON works - 13 t/s --- ggml-quants.c | 49 +++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index cf60c65cf..cb1c43461 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10048,9 +10048,18 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER); + const int16x8_t idx_shift = vld1q_s16(k_shift); + const uint16x8_t idx_mask1 = vdupq_n_u16(256); + const uint32x4_t idx_mask2 = vdupq_n_u32(0x0f0f0f0f); + const int8x16_t m1 = vdupq_n_s8(1); + const int8x16_t m0 = vdupq_n_s8(0); + uint8x16x2_t vs; ggml_int8x16x4_t q3s; ggml_int8x16x4_t q8b; @@ -10065,35 +10074,39 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v int sumi1 = 0, sumi2 = 0; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]}; - const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]}; - const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]}; - const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]}; - qs += 16; + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + const uint16x8_t idx_1 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), idx_shift), idx_mask1), + vmovl_u8(vget_low_u8(idx_l))); + const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1), + vmovl_u8(vget_high_u8(idx_l))); + q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)); + q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)); + q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)); + q3s.val[3] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)); + q3s.val[0] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[0], m1), m0), 1), 1), m1); + q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1); + q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1); + q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1)); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1)); - q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[0]); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[1]); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1)); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1)); signs += 4; - q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]); const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); @@ -10102,7 +10115,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } sumf += d*(sumi1 + sumi2); } - *s = 0.25f * sumf; + *s = sumf; #elif defined(__AVX2__) From 3000e0ac9e22828aceab6189cd87b24ea1ad8553 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 06:41:58 +0200 Subject: [PATCH 08/24] iq3_s_mult: Metal works - slower than lookup --- ggml-metal.metal | 115 ++++++++++++++--------------------------------- 1 file changed, 34 insertions(+), 81 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 689411903..1615f8cea 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2546,6 +2546,7 @@ typedef struct { uint8_t signs[QK_K/8]; uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +#define IQ3S_MULTIPLIER 190842953 typedef struct { half d; @@ -4083,73 +4084,6 @@ constexpr constant static uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -constexpr constant static uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, -}; - #define NGRID_IQ1S 512 constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, @@ -4757,14 +4691,23 @@ void kernel_mul_mv_iq3_s_f32_impl( threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; { + uint32_t aux32; + thread int8_t * q = (thread int8_t *)&aux32; int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) { + aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; + for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1; + values[pos + i] = aux32; + } threadgroup_barrier(mem_flags::mem_threadgroup); } const int ix = tiisg; + uint32_t aux32[2]; + thread const int8_t * grid = (thread const int8_t *)aux32; + device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { @@ -4786,13 +4729,21 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; row++) { const float db = dh[0]; - const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf)); + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f; + //threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + //threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + qs[2*l+0] + + select(0, 256, qh[0] & kmask_iq2xs[2*l+0])); + threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + qs[2*l+1] + + select(0, 256, qh[0] & kmask_iq2xs[2*l+1])); for (int j = 0; j < 4; ++j) { + //sum[0] += yl[8*l + j + 0] * (2*((grid[j+0] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + //sum[1] += yl[8*l + j + 4] * (2*((grid[j+4] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+4]); sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -4812,7 +4763,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -5703,18 +5654,20 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & device const uint8_t * qs = xb->qs + 8*ib32; device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256))); + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + uint32_t aux32[2]; + thread const int8_t * grid = (thread const int8_t *)aux32; + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + reg[0][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256))); + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } } From fe3c20b251a5f7b9f90b41799631096f308c1d9b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 07:51:20 +0200 Subject: [PATCH 09/24] iq3_s_mult: quantization tuning --- ggml-quants.c | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index cb1c43461..65531a79d 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -11986,9 +11986,10 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo scales[ib] = 0; continue; } + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; float best = 0; float scale = max/(2*kMaxQ-1); - for (int is = -15; is <= 15; ++is) { + for (int is = -9; is <= 9; ++is) { float id = (2*kMaxQ-1+is*0.2f)/max; float this_scale = 1/id; for (int k = 0; k < bs4; ++k) { @@ -12024,7 +12025,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo if (n_not_ongrid > 0 && scale > 0) { float id = 1/scale; for (int k = 0; k < bs4; ++k) { - if (is_on_grid[k]) continue; + //if (is_on_grid[k]) continue; uint16_t u = 0; for (int i = 0; i < 4; ++i) { int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); @@ -12074,24 +12075,14 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo scales[ib] = scale; max_scale = MAX(max_scale, scale); - //for (int k = 0; k < bs8; ++k) { - // for (int i = 0; i < 8; ++i) { - // float diff = scale*(2*L[8*k+i] + 1) * (block_signs[k] & (1 << i) ? -1 : 1) - xb[8*k+i]; - // block_mse += diff*diff; - // } - //} - } - //printf("Block %d: rmse = %g\n", ibl, (double)sqrtf(block_mse/QK_K)); - if (!max_scale) { continue; } float d = max_scale/31; - //y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); //1.02f); //1.0125f); - y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); //1.04f); //1.02f); //1.0125f); + y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); //1.033f); float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); From 726aed307a049e09e308c90bad7c0f0aea75a0fc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 08:51:28 +0200 Subject: [PATCH 10/24] iq3_s_mult: alternative multiplier / bit twidling --- ggml-cuda.cu | 31 +++++++++++++++++++--------- ggml-quants.c | 56 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6dec7f576..e82fbf06c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2370,9 +2370,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -//#define IQ3S_MULTIPLIER 746226 -//#define IQ3S_MULTIPLIER 677595 -#define IQ3S_MULTIPLIER 190842953LL +//#define IQ3S_MULTIPLIER 190842953LL + +//#define IQ3S_MULTIPLIER 5718026 +#define IQ3S_MULTIPLIER 898886 template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2390,11 +2391,17 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int8_t * grid = (const int8_t *)aux32; const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; - aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; + //aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; + //aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; + //aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101; + //aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101; + aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; + aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); @@ -5220,10 +5227,16 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const int8_t * q8 = bq8_1[ib32].qs; int sumi = 0; for (int l = 0; l < 4; ++l) { + //aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; + //aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; + //aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101; + //aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101; + aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; + aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); diff --git a/ggml-quants.c b/ggml-quants.c index 65531a79d..83846a1f2 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4125,7 +4125,13 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y //#define IQ3S_MULTIPLIER 746226 //#define IQ3S_MULTIPLIER 717154 //#define IQ3S_MULTIPLIER 677595 -#define IQ3S_MULTIPLIER 190842953 + +// Best PPL +//#define IQ3S_MULTIPLIER 190842953 +// +//#define IQ3S_MULTIPLIER 5718026 +#define IQ3S_MULTIPLIER 898886 + #define IQ3S_BITS 3 void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { @@ -4146,20 +4152,34 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); for (int l = 0; l < 4; ++l) { - aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //for (int j = 0; j < 8; ++j) { + // y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; + aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; for (int j = 0; j < 8; ++j) { - y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db1 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { - aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + //for (int j = 0; j < 8; ++j) { + // y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; + aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; for (int j = 0; j < 8; ++j) { - y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = db2 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -10136,7 +10156,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i m1 = _mm256_set1_epi8(1); const __m256i m7 = _mm256_set1_epi32(0x07070707); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); - const __m256i m0 = _mm256_setzero_si256(); + //const __m256i m0 = _mm256_setzero_si256(); + + // aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + // aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + // aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; + // aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -10158,11 +10183,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); // v = MAX(((IQ3S_MULTIPLIER * idx) & 0x0f0f0f0f) - 1, 0) - const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); + //const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); // v = (((v >> 1) & 0x07070707) << 1) | 0x01010101 + //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); + //const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); + //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); + + const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); + const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); @@ -11392,9 +11422,11 @@ static void iq3xs_init_grid512(void) { const uint8_t * q4 = (const uint8_t *)&aux32; for (int k = 0; k < grid_size; ++k) { int8_t * pos = (int8_t *)(the_grid + k); - aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; + //aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; + aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; for (int i = 0; i < 4; ++i) { - pos[i] = 2*((q4[i]-1)/2) + 1; + //pos[i] = 2*((q4[i]-1)/2) + 1; + pos[i] = 2*(q4[i]/2) + 1; } } From b6402fa757e650847f1183ebab3ddfa48f2aeb88 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 10:43:53 +0200 Subject: [PATCH 11/24] iq3_s_mult: ifdef'd slow / fast versions --- ggml-cuda.cu | 49 +++++++++++++--------- ggml-quants.c | 113 +++++++++++++++++++------------------------------- 2 files changed, 71 insertions(+), 91 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e82fbf06c..3b8032569 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2370,10 +2370,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -//#define IQ3S_MULTIPLIER 190842953LL - -//#define IQ3S_MULTIPLIER 5718026 -#define IQ3S_MULTIPLIER 898886 +#ifdef IQ3S_SLOW_MULT +// Better (lower PPL), but requires more bit twidling, so slower +#define IQ3S_MULTIPLIER 190842953LL +#else +//#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 842866 +#endif template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2391,17 +2394,18 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int8_t * grid = (const int8_t *)aux32; const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; - //aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; +#ifdef IQ3S_SLOW_MULT + aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; +#else aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; +#endif #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - //aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; - //aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; - //aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101; - //aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101; - aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; - aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; +#ifdef IQ3S_SLOW_MULT + aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; +#endif uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); @@ -2410,9 +2414,15 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ y[j] = d * grid[j]; } #else +#ifdef IQ3S_SLOW_MULT for (int j = 0; j < 8; ++j) { y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } +#else + for (int j = 0; j < 8; ++j) { + y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } +#endif #endif #else assert(false); @@ -5227,16 +5237,15 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const int8_t * q8 = bq8_1[ib32].qs; int sumi = 0; for (int l = 0; l < 4; ++l) { - //aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; - //aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; +#ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101; - //aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101; - aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; - aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; + aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; +#else + aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; +#endif uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); diff --git a/ggml-quants.c b/ggml-quants.c index 83846a1f2..7ce2f077a 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4121,16 +4121,13 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization -//#define IQ3S_MULTIPLIER 2469109 -//#define IQ3S_MULTIPLIER 746226 -//#define IQ3S_MULTIPLIER 717154 -//#define IQ3S_MULTIPLIER 677595 - +#ifdef IQ3S_SLOW_MULT // Best PPL -//#define IQ3S_MULTIPLIER 190842953 -// -//#define IQ3S_MULTIPLIER 5718026 -#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 190842953 +#else +//#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 842866 +#endif #define IQ3S_BITS 3 @@ -4152,32 +4149,34 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); for (int l = 0; l < 4; ++l) { - //aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //for (int j = 0; j < 8; ++j) { - // y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - //} +#ifdef IQ3S_SLOW_MULT + aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } +#else aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; - aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; for (int j = 0; j < 8; ++j) { y[j] = db1 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } +#endif y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { - //aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - //for (int j = 0; j < 8; ++j) { - // y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - //} +#ifdef IQ3S_SLOW_MULT + aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + } +#else aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; - aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; +#endif for (int j = 0; j < 8; ++j) { y[j] = db2 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } @@ -4187,34 +4186,6 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in qs += 8; signs += 4; } - - //for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - // const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; - // const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; - // for (int l = 0; l < 4; ++l) { - // const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - // const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); - // for (int j = 0; j < 4; ++j) { - // y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - // y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); - // } - // y += 8; - // } - // qs += 8; - // signs += 4; - // for (int l = 0; l < 4; ++l) { - // const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); - // const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); - // for (int j = 0; j < 4; ++j) { - // y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - // y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); - // } - // y += 8; - // } - // qh += 2; - // qs += 8; - // signs += 4; - //} } } @@ -10154,14 +10125,11 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); const __m256i m1 = _mm256_set1_epi8(1); - const __m256i m7 = _mm256_set1_epi32(0x07070707); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); - //const __m256i m0 = _mm256_setzero_si256(); - - // aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - // aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - // aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101; - // aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101; +#ifdef IQ3S_SLOW_MULT + const __m256i m7 = _mm256_set1_epi32(0x07070707); + const __m256i m0 = _mm256_setzero_si256(); +#endif __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -10182,18 +10150,19 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16))); const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); - // v = MAX(((IQ3S_MULTIPLIER * idx) & 0x0f0f0f0f) - 1, 0) - //const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); - // v = (((v >> 1) & 0x07070707) << 1) | 0x01010101 - //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - //const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); - //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); - - const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); +#ifdef IQ3S_SLOW_MULT + const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - - const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); +#else + //const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); + //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); + //const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); + const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); + const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); +#endif __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -11422,11 +11391,13 @@ static void iq3xs_init_grid512(void) { const uint8_t * q4 = (const uint8_t *)&aux32; for (int k = 0; k < grid_size; ++k) { int8_t * pos = (int8_t *)(the_grid + k); - //aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; +#ifdef IQ3S_SLOW_MULT + aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; +#else aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; +#endif for (int i = 0; i < 4; ++i) { - //pos[i] = 2*((q4[i]-1)/2) + 1; - pos[i] = 2*(q4[i]/2) + 1; + pos[i] = 2*((q4[i]-1)/2) + 1; } } From 5b9c8785faf7566ad3ab797d34dd26a8c1b18b67 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 11:30:01 +0200 Subject: [PATCH 12/24] iq3s_mult: ARM and Metal --- ggml-metal.metal | 49 ++++++++++++++++++++++++++++++++++++------------ ggml-quants.c | 11 +++++++++-- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 1615f8cea..69a928c24 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2546,7 +2546,12 @@ typedef struct { uint8_t signs[QK_K/8]; uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +#ifdef IQ3S_SLOW_MULT #define IQ3S_MULTIPLIER 190842953 +#else +//#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 842866 +#endif typedef struct { half d; @@ -4691,15 +4696,21 @@ void kernel_mul_mv_iq3_s_f32_impl( threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; { - uint32_t aux32; - thread int8_t * q = (thread int8_t *)&aux32; int nval = 8; int pos = (32*sgitg + tiisg)*nval; +#ifdef IQ3S_SLOW_MULT + uint32_t aux32; + thread int8_t * q = (thread int8_t *)&aux32; for (int i = 0; i < nval; ++i) { aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1; values[pos + i] = aux32; } +#else + for (int i = 0; i < nval; ++i) { + values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101; + } +#endif threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4733,17 +4744,16 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f; - //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f; - //threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - //threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); - threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + qs[2*l+0] + - select(0, 256, qh[0] & kmask_iq2xs[2*l+0])); - threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + qs[2*l+1] + - select(0, 256, qh[0] & kmask_iq2xs[2*l+1])); + // This is slower than pre-computing the grid in shared memory and loading from there + //aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + //for (int j = 0; j < 4; ++j) { + // sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + // sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + //} + threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { - //sum[0] += yl[8*l + j + 0] * (2*((grid[j+0] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+0]); - //sum[1] += yl[8*l + j + 4] * (2*((grid[j+4] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+4]); sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5657,6 +5667,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); uint32_t aux32[2]; thread const int8_t * grid = (thread const int8_t *)aux32; +#ifdef IQ3S_SLOW)MULT aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { @@ -5669,6 +5680,20 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]); reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } +#else + aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +#endif } template diff --git a/ggml-quants.c b/ggml-quants.c index 7ce2f077a..e4478102f 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10070,6 +10070,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v vmovl_u8(vget_low_u8(idx_l))); const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1), vmovl_u8(vget_high_u8(idx_l))); +#ifdef IQ3S_SLOW_MULT q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)); q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)); q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)); @@ -10078,6 +10079,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1); q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1); q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1); +#else + q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1); + q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1); + q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1); + q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1); +#endif vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); @@ -10094,8 +10101,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1)); vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1)); - signs += 4; - q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]); q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]); @@ -10103,6 +10108,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); + + signs += 4; } sumf += d*(sumi1 + sumi2); } From 8b713a987e50af54c81ce7e6323826c836013eaa Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 11:32:53 +0200 Subject: [PATCH 13/24] iq3s_mult: quantization tuning --- ggml-quants.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml-quants.c b/ggml-quants.c index e4478102f..e6f8389db 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12092,7 +12092,11 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } float d = max_scale/31; - y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); //1.033f); +#ifdef IQ3S_SLOW_MULT + y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); +#else + y[ibl].d = GGML_FP32_TO_FP16(d * 1.030f); +#endif float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); From dbe98dfe703abfa343a7980ec72aeb8e50671b10 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 13:13:52 +0200 Subject: [PATCH 14/24] iq3_s_mult: another alternative multiplier --- ggml-cuda.cu | 4 ++-- ggml-quants.c | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3b8032569..3df5b142b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2374,8 +2374,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds // Better (lower PPL), but requires more bit twidling, so slower #define IQ3S_MULTIPLIER 190842953LL #else -//#define IQ3S_MULTIPLIER 898886 -#define IQ3S_MULTIPLIER 842866 +#define IQ3S_MULTIPLIER 898886 +//#define IQ3S_MULTIPLIER 842866 #endif template diff --git a/ggml-quants.c b/ggml-quants.c index e6f8389db..f154d7c21 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4125,8 +4125,8 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // Best PPL #define IQ3S_MULTIPLIER 190842953 #else -//#define IQ3S_MULTIPLIER 898886 -#define IQ3S_MULTIPLIER 842866 +#define IQ3S_MULTIPLIER 898886 +//#define IQ3S_MULTIPLIER 842866 #endif #define IQ3S_BITS 3 From f4cb4eac45d703176130136672826c7598e5ba60 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 16:43:00 +0200 Subject: [PATCH 15/24] iq3_s_mult: play with blocks of 16 This brings the bpw to 3.5625. We come close but don't quite match lookup with 3.4375 bpw (blocks of 32) --- ggml-cuda.cu | 8 +++++--- ggml-quants.c | 24 +++++++++++++++++------- ggml-quants.h | 3 ++- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3df5b142b..37fdd10cb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -544,14 +544,15 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong #define QR3_XS 8 #define QI3_XS (QK_K / (4*QR3_XS)) +#define IQ3S_BLOCK_SIZE 16 typedef struct { half d; uint8_t qs[QK_K/4]; uint8_t qh[QK_K/32]; uint8_t signs[QK_K/8]; - uint8_t scales[QK_K/64]; + uint8_t scales[QK_K/(2*IQ3S_BLOCK_SIZE)]; } block_iq3_s; -static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding"); +static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + QK_K/(2*IQ3S_BLOCK_SIZE), "wrong iq3_s block size/padding"); #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) @@ -2392,7 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 8*ib; int32_t aux32[2]; const int8_t * grid = (const int8_t *)aux32; - const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); + const int is = (32*ib + 8*il)/IQ3S_BLOCK_SIZE; + const float d = (float)x[i].d * (1 + 2*((x[i].scales[is/2] >> 4*(is%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; #ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; diff --git a/ggml-quants.c b/ggml-quants.c index f154d7c21..cfa36b310 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4138,6 +4138,8 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in uint32_t aux32[2]; const int8_t * grid = (const int8_t *)aux32; + float db[64/IQ3S_BLOCK_SIZE]; + for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); @@ -4146,20 +4148,28 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const uint8_t * signs = x[i].signs; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); - const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#if IQ3S_BLOCK_SIZE == 16 + db[0] = d * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + db[1] = d * (1 + 2*(x[i].scales[ib32+0] >> 4)); + db[2] = d * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + db[3] = d * (1 + 2*(x[i].scales[ib32+1] >> 4)); +#else + db[0] = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + db[1] = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif for (int l = 0; l < 4; ++l) { + const float dl = db[8*l/IQ3S_BLOCK_SIZE]; #ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; for (int j = 0; j < 8; ++j) { - y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; for (int j = 0; j < 8; ++j) { - y[j] = db1 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #endif y += 8; @@ -4167,18 +4177,19 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { + const float dl = db[(8*l+32)/IQ3S_BLOCK_SIZE]; #ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; for (int j = 0; j < 8; ++j) { - y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; #endif for (int j = 0; j < 8; ++j) { - y[j] = db2 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -12109,7 +12120,6 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } } -#define IQ3S_BLOCK_SIZE 32 size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; GGML_ASSERT(n_per_row%QK_K == 0); diff --git a/ggml-quants.h b/ggml-quants.h index 2c61134c4..4ad5d69e7 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -201,10 +201,11 @@ typedef struct { static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); // 3.4375 bpw +#define IQ3S_BLOCK_SIZE 16 #if QK_K == 64 #define IQ3S_N_SCALE 2 #else -#define IQ3S_N_SCALE QK_K/64 +#define IQ3S_N_SCALE QK_K/(2*IQ3S_BLOCK_SIZE) #endif typedef struct { ggml_fp16_t d; From e5e72562c5d20b4de175ae6e387b91389e90eaa6 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 18:50:26 +0200 Subject: [PATCH 16/24] iq3_s_mult: back to blocks of 32 --- ggml-quants.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-quants.h b/ggml-quants.h index 4ad5d69e7..cb7af5961 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -201,7 +201,7 @@ typedef struct { static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); // 3.4375 bpw -#define IQ3S_BLOCK_SIZE 16 +#define IQ3S_BLOCK_SIZE 32 #if QK_K == 64 #define IQ3S_N_SCALE 2 #else From f2c2bd6b26c129cb4598488f1438e8da7bcdc55a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 19:12:05 +0200 Subject: [PATCH 17/24] iq3_s_mult: also CUDA --- ggml-cuda.cu | 17 ++++++++++++++++- ggml-quants.c | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 37fdd10cb..ff721ea43 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -544,7 +544,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong #define QR3_XS 8 #define QI3_XS (QK_K / (4*QR3_XS)) -#define IQ3S_BLOCK_SIZE 16 +#define IQ3S_BLOCK_SIZE 32 typedef struct { half d; uint8_t qs[QK_K/4]; @@ -5237,7 +5237,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; const int8_t * q8 = bq8_1[ib32].qs; +#if IQ3S_BLOCK_SIZE == 32 int sumi = 0; +#else + int sumi[2] = {0, 0}; +#endif for (int l = 0; l < 4; ++l) { #ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; @@ -5252,12 +5256,23 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); +#if IQ3S_BLOCK_SIZE == 32 sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); +#else + sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]); + sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]); +#endif q8 += 8; } +#if IQ3S_BLOCK_SIZE == 32 const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; +#else + int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf); + int ls2 = 1 + 2*(bq2->scales[ib32] >> 4); + return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]); +#endif #else assert(false); return 0.f; diff --git a/ggml-quants.c b/ggml-quants.c index cfa36b310..8d2cae527 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10037,6 +10037,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(by); UNUSED(bs); + GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented"); + const block_iq3_s * restrict x = vx; const block_q8_K * restrict y = vy; From b48bf8b411ffd788a778d8b72dfa78b759b95069 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Mar 2024 11:55:31 +0200 Subject: [PATCH 18/24] iq3_s_mult: scalar dot product --- ggml-quants.c | 106 +++++++++++++------------------------------------- 1 file changed, 28 insertions(+), 78 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index 8d2cae527..70af6e623 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3789,73 +3789,6 @@ static const uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const uint32_t iq3xs_grid[512] = { -0x04040404, 0x04142c14, 0x042c2424, 0x0404143c, 0x04140c0c, 0x042c0424, 0x04043434, 0x041c240c, -0x04341c1c, 0x040c0c34, 0x041c0404, 0x0434341c, 0x040c2c2c, 0x04241c04, 0x043c1414, 0x0414042c, -0x04243c04, 0x04042c14, 0x04142424, 0x042c143c, 0x04040c0c, 0x0c1c0424, 0x0c2c3434, 0x0c04240c, -0x0c1c141c, 0x0c340c34, 0x0c0c0404, 0x0c24341c, 0x0c34242c, 0x0c0c1c04, 0x0c240c14, 0x0c3c042c, -0x0c143404, 0x0c2c2c14, 0x14041c24, 0x1414143c, 0x142c040c, 0x14043c24, 0x141c2c34, 0x1434240c, -0x140c141c, 0x141c0c34, 0x14340404, 0x140c341c, 0x1424242c, 0x143c1c04, 0x14140c14, 0x1424042c, -0x1c043404, 0x1c142414, 0x1c2c1c24, 0x1c040c3c, 0x1c1c040c, 0x1c2c3424, 0x1c042c34, 0x1c1c1c0c, -0x1c34141c, 0x1c0c0434, 0x1c243c04, 0x1c342c1c, 0x1c0c242c, 0x1c241404, 0x243c0c14, 0x2414042c, -0x242c3404, 0x24042414, 0x24141c24, 0x242c0c3c, 0x2404040c, 0x241c3424, 0x24342434, 0x24041c0c, -0x241c0c1c, 0x24340434, 0x240c3404, 0x2c242c1c, 0x2c3c1c2c, 0x2c141404, 0x2c240414, 0x2c043c2c, -0x2c142c04, 0x2c2c2414, 0x2c041424, 0x2c1c0c3c, 0x2c2c040c, 0x2c043424, 0x2c1c2434, 0x2c341c0c, -0x2c0c0c1c, 0x34240434, 0x34343404, 0x340c2c1c, 0x34241c2c, 0x343c1404, 0x34140414, 0x342c342c, -0x34042c04, 0x34141c14, 0x342c1424, 0x3404043c, 0x341c3c0c, 0x34342c24, 0x3c042434, 0x3c1c140c, -0x3c340c1c, 0x3c0c0434, 0x3c243404, 0x3c3c241c, 0x3c0c1c2c, 0x04240c04, 0x04040414, 0x0414342c, -0x042c2c04, 0x04041c14, 0x041c1424, 0x042c043c, 0x04043c0c, 0x041c2c24, 0x04341c34, 0x040c140c, -0x0424041c, 0x04343c34, 0x040c2c04, 0x0424241c, 0x043c142c, 0x04140c04, 0x042c0414, 0x0404342c, -0x04142404, 0x042c1c14, 0x0c040c24, 0x0c1c043c, 0x0c34340c, 0x0c042c24, 0x0c1c1c34, 0x0c34140c, -0x0c0c041c, 0x0c243c34, 0x0c3c2c04, 0x0c0c241c, 0x0c24142c, 0x0c040404, 0x0c143c14, 0x142c2c2c, -0x14042404, 0x14141414, 0x142c0c24, 0x1404043c, 0x141c340c, 0x14342424, 0x140c1c34, 0x14240c0c, -0x1434041c, 0x140c3434, 0x14242c04, 0x143c1c1c, 0x1414142c, 0x1c2c0404, 0x1c043c14, 0x1c142c2c, -0x1c2c2404, 0x1c041414, 0x1c1c0c24, 0x1c343c3c, 0x1c042c0c, 0x1c1c2424, 0x1c341434, 0x1c0c0c0c, -0x1c24041c, 0x1c3c3434, 0x240c2404, 0x24241c1c, 0x24040c2c, 0x24140404, 0x242c3414, 0x24042c2c, -0x24141c04, 0x242c1414, 0x24040424, 0x241c3c3c, 0x24342c0c, 0x240c2424, 0x241c1434, 0x24340c0c, -0x2c0c041c, 0x2c243434, 0x2c3c2404, 0x2c14141c, 0x2c2c0c2c, 0x2c040404, 0x2c143414, 0x2c2c242c, -0x2c041c04, 0x2c1c0c14, 0x2c340424, 0x2c04343c, 0x2c1c2c0c, 0x2c341c24, 0x340c1434, 0x3424040c, -0x343c3c1c, 0x340c2c34, 0x34242404, 0x3404141c, 0x34140c2c, 0x342c0404, 0x34043414, 0x3414242c, -0x342c1c04, 0x34040c14, 0x341c0424, 0x3c34343c, 0x3c0c240c, 0x3c1c1c24, 0x3c340c34, 0x3c0c040c, -0x3c24341c, 0x3c3c2c34, 0x04141c04, 0x0424141c, 0x0404042c, 0x04143c04, 0x042c2c14, 0x0404242c, -0x041c1404, 0x04340c14, 0x04040424, 0x041c343c, 0x0434240c, 0x040c1c24, 0x04240c34, 0x043c040c, -0x040c341c, 0x04242434, 0x04041c04, 0x04140c1c, 0x042c042c, 0x04043404, 0x0c142c14, 0x0c2c1c2c, -0x0c041404, 0x0c1c0414, 0x0c343c24, 0x0c0c2c3c, 0x0c1c240c, 0x0c341424, 0x0c0c0c34, 0x0c24040c, -0x0c3c341c, 0x0c142434, 0x0c241c04, 0x0c040c1c, 0x1414042c, 0x142c3404, 0x14042c14, 0x141c1c2c, -0x142c1404, 0x14040414, 0x141c3424, 0x14342c3c, 0x140c1c0c, 0x14241424, 0x143c0434, 0x140c3c0c, -0x14242c1c, 0x1c042434, 0x1c141404, 0x1c2c0c1c, 0x1c04042c, 0x1c143404, 0x1c2c2414, 0x1c041c2c, -0x1c1c0c04, 0x1c340414, 0x1c0c3424, 0x1c1c2c3c, 0x1c341c0c, 0x1c0c1424, 0x1c240434, 0x243c3c0c, -0x24142c1c, 0x24241c34, 0x24041404, 0x2414041c, 0x242c3c2c, 0x24042c04, 0x241c2414, 0x242c142c, -0x24040c04, 0x241c0414, 0x24343424, 0x240c243c, 0x24241c0c, 0x2c340c24, 0x2c0c0434, 0x2c24340c, -0x2c3c2c1c, 0x2c141c34, 0x2c2c1404, 0x2c04041c, 0x2c143c2c, 0x2c2c2c04, 0x2c042414, 0x2c1c142c, -0x2c340404, 0x2c0c3c14, 0x341c2c24, 0x3434243c, 0x340c140c, 0x34240c24, 0x343c0434, 0x3414340c, -0x3424241c, 0x34041c34, 0x34140c04, 0x342c041c, 0x3404342c, 0x341c2c04, 0x342c1c14, 0x3404142c, -0x3c1c0404, 0x3c343c14, 0x3c0c2c24, 0x3c24243c, 0x3c34140c, 0x3c0c0c24, 0x3c243c34, 0x043c2c0c, -0x0414241c, 0x042c1434, 0x04040c04, 0x0414041c, 0x042c342c, 0x04042404, 0x041c1c14, 0x04340c2c, -0x040c0404, 0x041c3414, 0x04342c24, 0x040c1c3c, 0x0424140c, 0x043c0424, 0x04143c34, 0x04242c0c, -0x0404241c, 0x04141434, 0x042c0c04, 0x0c04041c, 0x0c1c342c, 0x0c2c2404, 0x0c041414, 0x0c1c0c2c, -0x0c340404, 0x0c0c3414, 0x0c242424, 0x0c341c3c, 0x0c0c0c0c, 0x0c240424, 0x0c3c3434, 0x0c142c0c, -0x0c2c1c1c, 0x14041434, 0x14140404, 0x142c3c1c, 0x14042c2c, 0x141c2404, 0x14341414, 0x14040c2c, -0x141c0404, 0x14343414, 0x140c2424, 0x14241c3c, 0x143c0c0c, 0x14140424, 0x1c243434, 0x1c04240c, -0x1c141c1c, 0x1c2c0c34, 0x1c040404, 0x1c1c341c, 0x1c2c2c2c, 0x1c041c04, 0x1c1c1414, 0x1c34042c, -0x1c0c3c04, 0x1c242c14, 0x1c342424, 0x1c0c143c, 0x24240c0c, 0x243c0424, 0x24143434, 0x242c240c, -0x24041c1c, 0x24140c34, 0x242c0404, 0x2404341c, 0x241c242c, 0x24341c04, 0x24040c14, 0x241c042c, -0x24343404, 0x2c0c2c14, 0x2c241c24, 0x2c3c143c, 0x2c0c040c, 0x2c243c24, 0x2c042c34, 0x2c14240c, -0x2c2c141c, 0x2c040c34, 0x2c1c0404, 0x2c2c341c, 0x2c04242c, 0x2c1c1c04, 0x2c340c14, 0x340c042c, -0x34243404, 0x34342c14, 0x340c1c24, 0x34240c3c, 0x343c040c, 0x34143424, 0x342c2c34, 0x34041c0c, -0x3414141c, 0x342c0434, 0x34043c04, 0x341c2c1c, 0x3434242c, 0x3c041404, 0x3c1c0c14, 0x3c34042c, -0x3c0c3404, 0x3c242414, 0x3c3c1c24, 0x040c0c3c, 0x0424040c, 0x04043424, 0x04142c34, 0x042c1c0c, -0x0404141c, 0x04140434, 0x042c3c04, 0x04042c1c, 0x041c1c2c, 0x04341404, 0x040c0414, 0x041c3c2c, -0x04342c04, 0x040c2414, 0x04241424, 0x043c0c3c, 0x0414040c, 0x042c3424, 0x04042434, 0x04141c0c, -0x0c2c0c1c, 0x0c040434, 0x0c1c3404, 0x0c342c1c, 0x0c041c2c, 0x0c1c1404, 0x0c340414, 0x0c0c3c2c, -0x0c242c04, 0x0c3c2414, 0x0c0c1424, 0x0c24043c, 0x0c043c0c, 0x14142c24, 0x142c2434, 0x1404140c, -0x14140c1c, 0x142c0434, 0x14043404, 0x141c241c, 0x14341c2c, 0x140c0c04, 0x141c0414, 0x1434342c, -0x140c2c04, 0x14241c14, 0x143c1424, 0x1c14043c, 0x1c243c0c, 0x1c042c24, 0x1c142434, 0x1c2c140c, -0x1c040c1c, 0x1c1c3c34, 0x1c342c04, 0x1c04241c, 0x1c1c142c, 0x1c340c04, 0x1c0c0414, 0x1c24342c, -0x1c3c2404, 0x240c1c14, 0x24240c24, 0x2404043c, 0x2414340c, 0x242c2c24, 0x24041c34, 0x2414140c, -0x242c041c, 0x24043c34, 0x241c2c04, 0x2434241c, 0x240c142c, 0x241c0c04, 0x2c340414, 0x2c0c342c, -}; - #define NGRID_IQ2XXS 512 static const uint64_t iq1s_grid[NGRID_IQ2XXS] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, @@ -10214,6 +10147,9 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v #else + uint32_t aux32[2]; + const uint8_t * grid = (const uint8_t *)aux32; + float sumf = 0.f; for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -10227,12 +10163,19 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; int32_t sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); +#ifdef IQ3S_SLOW_MULT + aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + sumi += (2*((grid[j]-1)/2) + 1) * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } +#else + aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } +#endif q8 += 8; } qs += 8; @@ -10240,11 +10183,18 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v bsum += sumi * ls1; sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); +#ifdef IQ3S_SLOW_MULT + aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + sumi += (2*((grid[j]-1)/2) + 1) * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } +#else + aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; +#endif + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } q8 += 8; } @@ -10254,7 +10204,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } sumf += d * bsum; } - *s = 0.25f * sumf; + *s = sumf; #endif } From b587482287014e5440da047376858d6af88a6340 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 4 Mar 2024 19:43:22 +0200 Subject: [PATCH 19/24] iq3_s_mult_shuffle: mult + shuffle based codebook --- ggml-cuda.cu | 53 +++++++++++++++++++++--------------- ggml-quants.c | 74 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 81 insertions(+), 46 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ff721ea43..6f8c4a3ac 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2375,8 +2375,12 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds // Better (lower PPL), but requires more bit twidling, so slower #define IQ3S_MULTIPLIER 190842953LL #else -#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 72968561ULL +//#define IQ3S_MULTIPLIER 540201 +//#define IQ3S_MULTIPLIER 1378231 +//#define IQ3S_MULTIPLIER 898886 //#define IQ3S_MULTIPLIER 842866 +static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; #endif template @@ -2400,32 +2404,36 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; #else - aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); #endif -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#ifdef IQ3S_SLOW_MULT - aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; -#endif - uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); - aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); - aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); - for (int j = 0; j < 8; ++j) { - y[j] = d * grid[j]; - } -#else +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +//#ifdef IQ3S_SLOW_MULT +// aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; +// aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; +//#endif +// uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); +// uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); +// aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); +// aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); +// for (int j = 0; j < 8; ++j) { +// y[j] = d * grid[j]; +// } +//#else #ifdef IQ3S_SLOW_MULT for (int j = 0; j < 8; ++j) { y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } #else + //static const uint8_t k_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; for (int j = 0; j < 8; ++j) { - y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + //y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } #endif -#endif +//#endif #else assert(false); #endif @@ -5225,7 +5233,6 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } -// TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -5233,6 +5240,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const block_iq3_s * bq2 = (const block_iq3_s *) vbq; uint32_t aux32[2]; + uint8_t * aux8 = (uint8_t *)aux32; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -5249,8 +5257,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; #else - aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; #endif uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); diff --git a/ggml-quants.c b/ggml-quants.c index 70af6e623..0535f5c25 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4058,12 +4058,18 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // Best PPL #define IQ3S_MULTIPLIER 190842953 #else -#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 72968561ULL +//#define IQ3S_MULTIPLIER 540201 +//#define IQ3S_MULTIPLIER 1378231 +//#define IQ3S_MULTIPLIER 898886 //#define IQ3S_MULTIPLIER 842866 #endif #define IQ3S_BITS 3 +static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; +//static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; + void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -4099,10 +4105,15 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else - aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //for (int j = 0; j < 8; ++j) { + // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { - y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #endif y += 8; @@ -4118,12 +4129,17 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else - aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; -#endif + //aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //for (int j = 0; j < 8; ++j) { + // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { - y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } +#endif y += 8; } qh += 2; @@ -10073,12 +10089,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m128i shuffle128 = _mm_loadu_si128((const __m128i *)iq3s_values); + const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128); - const __m256i idx_mask = _mm256_set1_epi32(256); - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); - const __m256i m1 = _mm256_set1_epi8(1); + //const __m256i m1 = _mm256_set1_epi8(1); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + const __m256i m100 = _mm256_set1_epi32(0x0100); #ifdef IQ3S_SLOW_MULT const __m256i m7 = _mm256_set1_epi32(0x07070707); const __m256i m0 = _mm256_setzero_si256(); @@ -10096,12 +10113,19 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8); - const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask); - const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask); - const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16))); - const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); + + const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)), + _mm256_setzero_si256()); + const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)), + _mm256_setzero_si256()); + const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1); + const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2); #ifdef IQ3S_SLOW_MULT const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); @@ -10109,12 +10133,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); #else - //const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); - //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - //const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); - //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); - const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); - const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15)); + const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15)); #endif __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); @@ -11364,10 +11384,14 @@ static void iq3xs_init_grid512(void) { #ifdef IQ3S_SLOW_MULT aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; #else - aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; + //aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; + aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f); #endif + //for (int i = 0; i < 4; ++i) { + // pos[i] = 2*((q4[i]-1)/2) + 1; + //} for (int i = 0; i < 4; ++i) { - pos[i] = 2*((q4[i]-1)/2) + 1; + pos[i] = iq3s_values[q4[i]]; } } From a6a263b919a89cd5523286c1c09f630529967788 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 4 Mar 2024 20:10:36 +0200 Subject: [PATCH 20/24] iq3_s_mult_shuffle: works on ARM_NEON and Metal --- ggml-metal.metal | 28 ++++++++++++++++------------ ggml-quants.c | 9 +++++---- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 69a928c24..550bc682e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2550,7 +2550,9 @@ typedef struct { #define IQ3S_MULTIPLIER 190842953 #else //#define IQ3S_MULTIPLIER 898886 -#define IQ3S_MULTIPLIER 842866 +//#define IQ3S_MULTIPLIER 842866 +#define IQ3S_MULTIPLIER 72968561ULL +constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; #endif typedef struct { @@ -4698,9 +4700,9 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; -#ifdef IQ3S_SLOW_MULT uint32_t aux32; thread int8_t * q = (thread int8_t *)&aux32; +#ifdef IQ3S_SLOW_MULT for (int i = 0; i < nval; ++i) { aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1; @@ -4708,7 +4710,9 @@ void kernel_mul_mv_iq3_s_f32_impl( } #else for (int i = 0; i < nval; ++i) { - values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101; + aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; + for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; + values[pos + i] = aux32; } #endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5667,7 +5671,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); uint32_t aux32[2]; thread const int8_t * grid = (thread const int8_t *)aux32; -#ifdef IQ3S_SLOW)MULT +#ifdef IQ3S_SLOW_MULT aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { @@ -5681,17 +5685,17 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } #else - aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } #endif } diff --git a/ggml-quants.c b/ggml-quants.c index 0535f5c25..e7fc0f85c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10005,6 +10005,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t shuff = vld1q_u8(iq3s_values); const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER); const int16x8_t idx_shift = vld1q_s16(k_shift); @@ -10042,10 +10043,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1); q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1); #else - q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1); - q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1); - q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1); - q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1); + q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2))); + q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2))); + q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2))); + q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2))); #endif vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); From b1d753be34825503bb62bd30ffcf1dc26e3b91af Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 5 Mar 2024 08:23:37 +0200 Subject: [PATCH 21/24] iq3_s_mult: remove SLOW_MULT option --- ggml-metal.metal | 29 --------------- ggml-quants.c | 92 ++---------------------------------------------- 2 files changed, 2 insertions(+), 119 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 550bc682e..176287fcd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2546,14 +2546,8 @@ typedef struct { uint8_t signs[QK_K/8]; uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; -#ifdef IQ3S_SLOW_MULT -#define IQ3S_MULTIPLIER 190842953 -#else -//#define IQ3S_MULTIPLIER 898886 -//#define IQ3S_MULTIPLIER 842866 #define IQ3S_MULTIPLIER 72968561ULL constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; -#endif typedef struct { half d; @@ -4702,19 +4696,11 @@ void kernel_mul_mv_iq3_s_f32_impl( int pos = (32*sgitg + tiisg)*nval; uint32_t aux32; thread int8_t * q = (thread int8_t *)&aux32; -#ifdef IQ3S_SLOW_MULT - for (int i = 0; i < nval; ++i) { - aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; - for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1; - values[pos + i] = aux32; - } -#else for (int i = 0; i < nval; ++i) { aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; values[pos + i] = aux32; } -#endif threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5671,20 +5657,6 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); uint32_t aux32[2]; thread const int8_t * grid = (thread const int8_t *)aux32; -#ifdef IQ3S_SLOW_MULT - aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; - aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[0] & kmask_iq2xs[i+4]); - } - aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; - aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]); - } -#else aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { @@ -5697,7 +5669,6 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } -#endif } template diff --git a/ggml-quants.c b/ggml-quants.c index e7fc0f85c..ef21bb487 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4054,21 +4054,11 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization -#ifdef IQ3S_SLOW_MULT -// Best PPL -#define IQ3S_MULTIPLIER 190842953 -#else #define IQ3S_MULTIPLIER 72968561ULL -//#define IQ3S_MULTIPLIER 540201 -//#define IQ3S_MULTIPLIER 1378231 -//#define IQ3S_MULTIPLIER 898886 -//#define IQ3S_MULTIPLIER 842866 -#endif #define IQ3S_BITS 3 static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; -//static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); @@ -4098,48 +4088,22 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in #endif for (int l = 0; l < 4; ++l) { const float dl = db[8*l/IQ3S_BLOCK_SIZE]; -#ifdef IQ3S_SLOW_MULT - aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - for (int j = 0; j < 8; ++j) { - y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - } -#else - //aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //for (int j = 0; j < 8; ++j) { - // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - //} aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } -#endif y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { const float dl = db[(8*l+32)/IQ3S_BLOCK_SIZE]; -#ifdef IQ3S_SLOW_MULT - aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - for (int j = 0; j < 8; ++j) { - y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - } -#else - //aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //for (int j = 0; j < 8; ++j) { - // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); - //} aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } -#endif y += 8; } qh += 2; @@ -10005,14 +9969,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); const uint8x16_t mask2 = vld1q_u8(k_mask2); - const uint8x16_t shuff = vld1q_u8(iq3s_values); + const int8x16_t shuff = vld1q_s8((const int8_t *)iq3s_values); const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER); const int16x8_t idx_shift = vld1q_s16(k_shift); const uint16x8_t idx_mask1 = vdupq_n_u16(256); const uint32x4_t idx_mask2 = vdupq_n_u32(0x0f0f0f0f); const int8x16_t m1 = vdupq_n_s8(1); - const int8x16_t m0 = vdupq_n_s8(0); uint8x16x2_t vs; ggml_int8x16x4_t q3s; @@ -10033,21 +9996,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v vmovl_u8(vget_low_u8(idx_l))); const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1), vmovl_u8(vget_high_u8(idx_l))); -#ifdef IQ3S_SLOW_MULT - q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)); - q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)); - q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)); - q3s.val[3] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)); - q3s.val[0] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[0], m1), m0), 1), 1), m1); - q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1); - q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1); - q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1); -#else q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2))); q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2))); q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2))); q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2))); -#endif vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); @@ -10094,13 +10046,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); - //const __m256i m1 = _mm256_set1_epi8(1); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); const __m256i m100 = _mm256_set1_epi32(0x0100); -#ifdef IQ3S_SLOW_MULT - const __m256i m7 = _mm256_set1_epi32(0x07070707); - const __m256i m0 = _mm256_setzero_si256(); -#endif __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -10128,15 +10075,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1); const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2); -#ifdef IQ3S_SLOW_MULT - const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); - const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); - const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); -#else const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15)); const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15)); -#endif __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -10184,19 +10124,11 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; int32_t sumi = 0; for (int l = 0; l < 4; ++l) { -#ifdef IQ3S_SLOW_MULT aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f; - for (int j = 0; j < 8; ++j) { - sumi += (2*((grid[j]-1)/2) + 1) * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } -#else - aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; for (int j = 0; j < 8; ++j) { sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } -#endif q8 += 8; } qs += 8; @@ -10204,18 +10136,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v bsum += sumi * ls1; sumi = 0; for (int l = 0; l < 4; ++l) { -#ifdef IQ3S_SLOW_MULT aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f; for (int j = 0; j < 8; ++j) { - sumi += (2*((grid[j]-1)/2) + 1) * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } -#else - aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; -#endif - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + sumi += q8[j] * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } q8 += 8; } @@ -11382,15 +11306,7 @@ static void iq3xs_init_grid512(void) { const uint8_t * q4 = (const uint8_t *)&aux32; for (int k = 0; k < grid_size; ++k) { int8_t * pos = (int8_t *)(the_grid + k); -#ifdef IQ3S_SLOW_MULT - aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; -#else - //aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f); -#endif - //for (int i = 0; i < 4; ++i) { - // pos[i] = 2*((q4[i]-1)/2) + 1; - //} for (int i = 0; i < 4; ++i) { pos[i] = iq3s_values[q4[i]]; } @@ -12080,11 +11996,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } float d = max_scale/31; -#ifdef IQ3S_SLOW_MULT - y[ibl].d = GGML_FP32_TO_FP16(d * 1.025f); -#else y[ibl].d = GGML_FP32_TO_FP16(d * 1.030f); -#endif float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); From 6d15da1ec08542af9c71599a1d1bff2066786b47 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 5 Mar 2024 08:36:57 +0200 Subject: [PATCH 22/24] iq3_s_mult_shuffle: use new multiplier and cleanup --- ggml-cuda.cu | 49 +----------------------------------------------- ggml-metal.metal | 2 +- ggml-quants.c | 3 +-- 3 files changed, 3 insertions(+), 51 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6f8c4a3ac..373f03a23 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2371,17 +2371,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -#ifdef IQ3S_SLOW_MULT -// Better (lower PPL), but requires more bit twidling, so slower -#define IQ3S_MULTIPLIER 190842953LL -#else -#define IQ3S_MULTIPLIER 72968561ULL -//#define IQ3S_MULTIPLIER 540201 -//#define IQ3S_MULTIPLIER 1378231 -//#define IQ3S_MULTIPLIER 898886 -//#define IQ3S_MULTIPLIER 842866 +#define IQ3S_MULTIPLIER 518559 static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; -#endif template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2400,40 +2391,11 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int is = (32*ib + 8*il)/IQ3S_BLOCK_SIZE; const float d = (float)x[i].d * (1 + 2*((x[i].scales[is/2] >> 4*(is%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; -#ifdef IQ3S_SLOW_MULT - aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; -#else - //aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); -#endif -//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -//#ifdef IQ3S_SLOW_MULT -// aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; -// aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; -//#endif -// uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); -// uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); -// aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); -// aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); -// for (int j = 0; j < 8; ++j) { -// y[j] = d * grid[j]; -// } -//#else -#ifdef IQ3S_SLOW_MULT for (int j = 0; j < 8; ++j) { - y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } -#else - //static const uint8_t k_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; - for (int j = 0; j < 8; ++j) { - //y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } -#endif -//#endif #else assert(false); #endif @@ -5251,18 +5213,9 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( int sumi[2] = {0, 0}; #endif for (int l = 0; l < 4; ++l) { -#ifdef IQ3S_SLOW_MULT - aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; -#else - //aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - //aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; -#endif uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 176287fcd..8c3ac9e34 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2546,7 +2546,7 @@ typedef struct { uint8_t signs[QK_K/8]; uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; -#define IQ3S_MULTIPLIER 72968561ULL +#define IQ3S_MULTIPLIER 518559 constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; typedef struct { diff --git a/ggml-quants.c b/ggml-quants.c index ef21bb487..c83b7c775 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4054,8 +4054,7 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization -#define IQ3S_MULTIPLIER 72968561ULL - +#define IQ3S_MULTIPLIER 518559 #define IQ3S_BITS 3 static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; From 93034df760c2ff90f5bc68df8fbdc6c90183ab18 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 5 Mar 2024 10:06:07 +0200 Subject: [PATCH 23/24] iq3_s_mult_shuffle: use lookup table on CUDA ~4% faster TG that way. --- ggml-cuda.cu | 171 +++++++++++++++++++++++++++------------------------ 1 file changed, 90 insertions(+), 81 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 373f03a23..785a7a3f9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2009,72 +2009,72 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -//static const __device__ uint32_t iq3xs_grid[512] = { -//0x04040404, 0x04142c14, 0x042c2424, 0x0404143c, 0x04140c0c, 0x042c0424, 0x04043434, 0x041c240c, -//0x04341c1c, 0x040c0c34, 0x041c0404, 0x0434341c, 0x040c2c2c, 0x04241c04, 0x043c1414, 0x0414042c, -//0x04243c04, 0x04042c14, 0x04142424, 0x042c143c, 0x04040c0c, 0x0c1c0424, 0x0c2c3434, 0x0c04240c, -//0x0c1c141c, 0x0c340c34, 0x0c0c0404, 0x0c24341c, 0x0c34242c, 0x0c0c1c04, 0x0c240c14, 0x0c3c042c, -//0x0c143404, 0x0c2c2c14, 0x14041c24, 0x1414143c, 0x142c040c, 0x14043c24, 0x141c2c34, 0x1434240c, -//0x140c141c, 0x141c0c34, 0x14340404, 0x140c341c, 0x1424242c, 0x143c1c04, 0x14140c14, 0x1424042c, -//0x1c043404, 0x1c142414, 0x1c2c1c24, 0x1c040c3c, 0x1c1c040c, 0x1c2c3424, 0x1c042c34, 0x1c1c1c0c, -//0x1c34141c, 0x1c0c0434, 0x1c243c04, 0x1c342c1c, 0x1c0c242c, 0x1c241404, 0x243c0c14, 0x2414042c, -//0x242c3404, 0x24042414, 0x24141c24, 0x242c0c3c, 0x2404040c, 0x241c3424, 0x24342434, 0x24041c0c, -//0x241c0c1c, 0x24340434, 0x240c3404, 0x2c242c1c, 0x2c3c1c2c, 0x2c141404, 0x2c240414, 0x2c043c2c, -//0x2c142c04, 0x2c2c2414, 0x2c041424, 0x2c1c0c3c, 0x2c2c040c, 0x2c043424, 0x2c1c2434, 0x2c341c0c, -//0x2c0c0c1c, 0x34240434, 0x34343404, 0x340c2c1c, 0x34241c2c, 0x343c1404, 0x34140414, 0x342c342c, -//0x34042c04, 0x34141c14, 0x342c1424, 0x3404043c, 0x341c3c0c, 0x34342c24, 0x3c042434, 0x3c1c140c, -//0x3c340c1c, 0x3c0c0434, 0x3c243404, 0x3c3c241c, 0x3c0c1c2c, 0x04240c04, 0x04040414, 0x0414342c, -//0x042c2c04, 0x04041c14, 0x041c1424, 0x042c043c, 0x04043c0c, 0x041c2c24, 0x04341c34, 0x040c140c, -//0x0424041c, 0x04343c34, 0x040c2c04, 0x0424241c, 0x043c142c, 0x04140c04, 0x042c0414, 0x0404342c, -//0x04142404, 0x042c1c14, 0x0c040c24, 0x0c1c043c, 0x0c34340c, 0x0c042c24, 0x0c1c1c34, 0x0c34140c, -//0x0c0c041c, 0x0c243c34, 0x0c3c2c04, 0x0c0c241c, 0x0c24142c, 0x0c040404, 0x0c143c14, 0x142c2c2c, -//0x14042404, 0x14141414, 0x142c0c24, 0x1404043c, 0x141c340c, 0x14342424, 0x140c1c34, 0x14240c0c, -//0x1434041c, 0x140c3434, 0x14242c04, 0x143c1c1c, 0x1414142c, 0x1c2c0404, 0x1c043c14, 0x1c142c2c, -//0x1c2c2404, 0x1c041414, 0x1c1c0c24, 0x1c343c3c, 0x1c042c0c, 0x1c1c2424, 0x1c341434, 0x1c0c0c0c, -//0x1c24041c, 0x1c3c3434, 0x240c2404, 0x24241c1c, 0x24040c2c, 0x24140404, 0x242c3414, 0x24042c2c, -//0x24141c04, 0x242c1414, 0x24040424, 0x241c3c3c, 0x24342c0c, 0x240c2424, 0x241c1434, 0x24340c0c, -//0x2c0c041c, 0x2c243434, 0x2c3c2404, 0x2c14141c, 0x2c2c0c2c, 0x2c040404, 0x2c143414, 0x2c2c242c, -//0x2c041c04, 0x2c1c0c14, 0x2c340424, 0x2c04343c, 0x2c1c2c0c, 0x2c341c24, 0x340c1434, 0x3424040c, -//0x343c3c1c, 0x340c2c34, 0x34242404, 0x3404141c, 0x34140c2c, 0x342c0404, 0x34043414, 0x3414242c, -//0x342c1c04, 0x34040c14, 0x341c0424, 0x3c34343c, 0x3c0c240c, 0x3c1c1c24, 0x3c340c34, 0x3c0c040c, -//0x3c24341c, 0x3c3c2c34, 0x04141c04, 0x0424141c, 0x0404042c, 0x04143c04, 0x042c2c14, 0x0404242c, -//0x041c1404, 0x04340c14, 0x04040424, 0x041c343c, 0x0434240c, 0x040c1c24, 0x04240c34, 0x043c040c, -//0x040c341c, 0x04242434, 0x04041c04, 0x04140c1c, 0x042c042c, 0x04043404, 0x0c142c14, 0x0c2c1c2c, -//0x0c041404, 0x0c1c0414, 0x0c343c24, 0x0c0c2c3c, 0x0c1c240c, 0x0c341424, 0x0c0c0c34, 0x0c24040c, -//0x0c3c341c, 0x0c142434, 0x0c241c04, 0x0c040c1c, 0x1414042c, 0x142c3404, 0x14042c14, 0x141c1c2c, -//0x142c1404, 0x14040414, 0x141c3424, 0x14342c3c, 0x140c1c0c, 0x14241424, 0x143c0434, 0x140c3c0c, -//0x14242c1c, 0x1c042434, 0x1c141404, 0x1c2c0c1c, 0x1c04042c, 0x1c143404, 0x1c2c2414, 0x1c041c2c, -//0x1c1c0c04, 0x1c340414, 0x1c0c3424, 0x1c1c2c3c, 0x1c341c0c, 0x1c0c1424, 0x1c240434, 0x243c3c0c, -//0x24142c1c, 0x24241c34, 0x24041404, 0x2414041c, 0x242c3c2c, 0x24042c04, 0x241c2414, 0x242c142c, -//0x24040c04, 0x241c0414, 0x24343424, 0x240c243c, 0x24241c0c, 0x2c340c24, 0x2c0c0434, 0x2c24340c, -//0x2c3c2c1c, 0x2c141c34, 0x2c2c1404, 0x2c04041c, 0x2c143c2c, 0x2c2c2c04, 0x2c042414, 0x2c1c142c, -//0x2c340404, 0x2c0c3c14, 0x341c2c24, 0x3434243c, 0x340c140c, 0x34240c24, 0x343c0434, 0x3414340c, -//0x3424241c, 0x34041c34, 0x34140c04, 0x342c041c, 0x3404342c, 0x341c2c04, 0x342c1c14, 0x3404142c, -//0x3c1c0404, 0x3c343c14, 0x3c0c2c24, 0x3c24243c, 0x3c34140c, 0x3c0c0c24, 0x3c243c34, 0x043c2c0c, -//0x0414241c, 0x042c1434, 0x04040c04, 0x0414041c, 0x042c342c, 0x04042404, 0x041c1c14, 0x04340c2c, -//0x040c0404, 0x041c3414, 0x04342c24, 0x040c1c3c, 0x0424140c, 0x043c0424, 0x04143c34, 0x04242c0c, -//0x0404241c, 0x04141434, 0x042c0c04, 0x0c04041c, 0x0c1c342c, 0x0c2c2404, 0x0c041414, 0x0c1c0c2c, -//0x0c340404, 0x0c0c3414, 0x0c242424, 0x0c341c3c, 0x0c0c0c0c, 0x0c240424, 0x0c3c3434, 0x0c142c0c, -//0x0c2c1c1c, 0x14041434, 0x14140404, 0x142c3c1c, 0x14042c2c, 0x141c2404, 0x14341414, 0x14040c2c, -//0x141c0404, 0x14343414, 0x140c2424, 0x14241c3c, 0x143c0c0c, 0x14140424, 0x1c243434, 0x1c04240c, -//0x1c141c1c, 0x1c2c0c34, 0x1c040404, 0x1c1c341c, 0x1c2c2c2c, 0x1c041c04, 0x1c1c1414, 0x1c34042c, -//0x1c0c3c04, 0x1c242c14, 0x1c342424, 0x1c0c143c, 0x24240c0c, 0x243c0424, 0x24143434, 0x242c240c, -//0x24041c1c, 0x24140c34, 0x242c0404, 0x2404341c, 0x241c242c, 0x24341c04, 0x24040c14, 0x241c042c, -//0x24343404, 0x2c0c2c14, 0x2c241c24, 0x2c3c143c, 0x2c0c040c, 0x2c243c24, 0x2c042c34, 0x2c14240c, -//0x2c2c141c, 0x2c040c34, 0x2c1c0404, 0x2c2c341c, 0x2c04242c, 0x2c1c1c04, 0x2c340c14, 0x340c042c, -//0x34243404, 0x34342c14, 0x340c1c24, 0x34240c3c, 0x343c040c, 0x34143424, 0x342c2c34, 0x34041c0c, -//0x3414141c, 0x342c0434, 0x34043c04, 0x341c2c1c, 0x3434242c, 0x3c041404, 0x3c1c0c14, 0x3c34042c, -//0x3c0c3404, 0x3c242414, 0x3c3c1c24, 0x040c0c3c, 0x0424040c, 0x04043424, 0x04142c34, 0x042c1c0c, -//0x0404141c, 0x04140434, 0x042c3c04, 0x04042c1c, 0x041c1c2c, 0x04341404, 0x040c0414, 0x041c3c2c, -//0x04342c04, 0x040c2414, 0x04241424, 0x043c0c3c, 0x0414040c, 0x042c3424, 0x04042434, 0x04141c0c, -//0x0c2c0c1c, 0x0c040434, 0x0c1c3404, 0x0c342c1c, 0x0c041c2c, 0x0c1c1404, 0x0c340414, 0x0c0c3c2c, -//0x0c242c04, 0x0c3c2414, 0x0c0c1424, 0x0c24043c, 0x0c043c0c, 0x14142c24, 0x142c2434, 0x1404140c, -//0x14140c1c, 0x142c0434, 0x14043404, 0x141c241c, 0x14341c2c, 0x140c0c04, 0x141c0414, 0x1434342c, -//0x140c2c04, 0x14241c14, 0x143c1424, 0x1c14043c, 0x1c243c0c, 0x1c042c24, 0x1c142434, 0x1c2c140c, -//0x1c040c1c, 0x1c1c3c34, 0x1c342c04, 0x1c04241c, 0x1c1c142c, 0x1c340c04, 0x1c0c0414, 0x1c24342c, -//0x1c3c2404, 0x240c1c14, 0x24240c24, 0x2404043c, 0x2414340c, 0x242c2c24, 0x24041c34, 0x2414140c, -//0x242c041c, 0x24043c34, 0x241c2c04, 0x2434241c, 0x240c142c, 0x241c0c04, 0x2c340414, 0x2c0c342c, -//}; +static const __device__ uint32_t iq3s_grid[512] = { + 0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307, + 0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101, + 0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07, + 0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701, + 0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507, + 0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301, + 0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107, + 0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01, + 0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907, + 0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501, + 0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307, + 0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101, + 0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07, + 0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901, + 0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507, + 0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301, + 0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107, + 0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01, + 0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907, + 0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701, + 0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507, + 0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301, + 0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107, + 0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01, + 0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707, + 0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501, + 0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307, + 0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101, + 0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07, + 0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901, + 0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507, + 0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301, + 0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107, + 0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01, + 0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907, + 0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701, + 0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507, + 0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301, + 0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107, + 0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01, + 0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707, + 0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501, + 0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307, + 0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101, + 0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07, + 0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701, + 0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507, + 0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301, + 0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107, + 0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01, + 0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907, + 0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501, + 0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307, + 0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101, + 0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07, + 0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901, + 0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707, + 0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501, + 0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307, + 0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101, + 0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07, + 0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701, + 0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507, + 0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301, +}; static const __device__ uint64_t iq1s_grid[512] = { @@ -2371,8 +2371,9 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -#define IQ3S_MULTIPLIER 518559 -static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; +// On CUDA it is fuster to use a lookup table instead of directly computing using these +//#define IQ3S_MULTIPLIER 518559 +//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2386,16 +2387,22 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; - int32_t aux32[2]; - const int8_t * grid = (const int8_t *)aux32; + //int32_t aux32[2]; + //const int8_t * grid = (const int8_t *)aux32; const int is = (32*ib + 8*il)/IQ3S_BLOCK_SIZE; const float d = (float)x[i].d * (1 + 2*((x[i].scales[is/2] >> 4*(is%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; - aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); - aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); - for (int j = 0; j < 8; ++j) { - y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } + //aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //for (int j = 0; j < 8; ++j) { + // y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + //} #else assert(false); #endif @@ -5201,8 +5208,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( #if QK_K == 256 const block_iq3_s * bq2 = (const block_iq3_s *) vbq; - uint32_t aux32[2]; - uint8_t * aux8 = (uint8_t *)aux32; + //uint32_t aux32[2]; + //uint8_t * aux8 = (uint8_t *)aux32; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -5213,13 +5220,15 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( int sumi[2] = {0, 0}; #endif for (int l = 0; l < 4; ++l) { - aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); - aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); - for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; + //aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); - const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); + //const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); + //const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); + const int grid_l = __vsub4(iq3s_grid[qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)] ^ signs0, signs0); + const int grid_h = __vsub4(iq3s_grid[qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)] ^ signs1, signs1); #if IQ3S_BLOCK_SIZE == 32 sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); From 31cecc8734537b38052a61c7060fabcf6094e35b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 5 Mar 2024 10:19:44 +0200 Subject: [PATCH 24/24] iq3_s_mult_shuffle: use lookup table on Metal ~4% faster TG and ~2% faster PP that way. --- ggml-metal.metal | 130 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 107 insertions(+), 23 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 8c3ac9e34..90fe88884 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2546,8 +2546,10 @@ typedef struct { uint8_t signs[QK_K/8]; uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; -#define IQ3S_MULTIPLIER 518559 -constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; + +// When a shuffle is involved in the codebook, on Metal it is faster to use a lookup table +//#define IQ3S_MULTIPLIER 518559 +//constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; typedef struct { half d; @@ -4085,6 +4087,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; +constexpr constant static uint32_t iq3s_grid[512] = { + 0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307, + 0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101, + 0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07, + 0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701, + 0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507, + 0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301, + 0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107, + 0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01, + 0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907, + 0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501, + 0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307, + 0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101, + 0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07, + 0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901, + 0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507, + 0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301, + 0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107, + 0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01, + 0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907, + 0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701, + 0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507, + 0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301, + 0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107, + 0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01, + 0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707, + 0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501, + 0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307, + 0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101, + 0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07, + 0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901, + 0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507, + 0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301, + 0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107, + 0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01, + 0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907, + 0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701, + 0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507, + 0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301, + 0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107, + 0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01, + 0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707, + 0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501, + 0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307, + 0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101, + 0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07, + 0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701, + 0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507, + 0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301, + 0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107, + 0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01, + 0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907, + 0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501, + 0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307, + 0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101, + 0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07, + 0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901, + 0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707, + 0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501, + 0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307, + 0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101, + 0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07, + 0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701, + 0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507, + 0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301, +}; + #define NGRID_IQ1S 512 constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, @@ -4694,20 +4763,23 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - uint32_t aux32; - thread int8_t * q = (thread int8_t *)&aux32; for (int i = 0; i < nval; ++i) { - aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; - for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; - values[pos + i] = aux32; + values[pos + i] = iq3s_grid[pos + i]; } + //uint32_t aux32; + //thread int8_t * q = (thread int8_t *)&aux32; + //for (int i = 0; i < nval; ++i) { + // aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; + // for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; + // values[pos + i] = aux32; + //} threadgroup_barrier(mem_flags::mem_threadgroup); } const int ix = tiisg; - uint32_t aux32[2]; - thread const int8_t * grid = (thread const int8_t *)aux32; + //uint32_t aux32[2]; + //thread const int8_t * grid = (thread const int8_t *)aux32; device const float * y4 = y + 32 * ix; @@ -4735,11 +4807,11 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { // This is slower than pre-computing the grid in shared memory and loading from there - //aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; - //aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f; //for (int j = 0; j < 4; ++j) { - // sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); - // sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + // sum[0] += yl[8*l + j + 0] * iq3s_values[grid[j+0]] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + // sum[1] += yl[8*l + j + 4] * iq3s_values[grid[j+4]] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); //} threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); @@ -5655,20 +5727,32 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; const uint8_t qh = xb->qh[ib32] >> 4*il; const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); - uint32_t aux32[2]; - thread const int8_t * grid = (thread const int8_t *)aux32; - aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; - aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; - aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } + //uint32_t aux32[2]; + //thread const int8_t * grid = (thread const int8_t *)aux32; + //aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; + //for (int i = 0; i < 4; ++i) { + // reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + // reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + //} + //aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; + //for (int i = 0; i < 4; ++i) { + // reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + // reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + //} } template