diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c207ff87a..c4e81122b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -566,7 +566,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N typedef struct { half d; uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); @@ -1723,9 +1724,8 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = (float)x[i].d * (2*(h & 7) + 1); + const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | (((x[i].qh[i8/4] >> 2*(i8%4)) & 3) << 8))); + const float d = (float)x[i].d * (2*((x[i].scales[ib] >> 4*(il/2)) & 0xf) + 1); for (int j = 0; j < 8; ++j) y[j] = d * grid[j]; #else assert(false); diff --git a/ggml-quants.c b/ggml-quants.c index 42d8a5d80..cddbae8b3 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9996,7 +9996,7 @@ static inline int iq2_grid_size(enum ggml_type type) { GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S); return type == GGML_TYPE_IQ2_XXS ? 256 : type == GGML_TYPE_IQ2_XS ? 512 : - type == GGML_TYPE_IQ1_S ? 512 : 1024; + type == GGML_TYPE_IQ1_S ? NGRID_IQ1S : 1024; } static int iq2_compare_func(const void * left, const void * right) { @@ -10063,39 +10063,71 @@ void iq2xs_init_impl(enum ggml_type type) { 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048, 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690, }; - static const uint16_t kgrid_1bit_512[512] = { - 10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545, - 553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444, - 1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440, - 2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422, - 4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397, - 5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769, - 5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788, - 6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794, - 9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272, - 10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665, - 16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685, - 17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529, - 18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517, - 20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872, - 20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653, - 21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842, - 21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913, - 21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608, - 22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072, - 23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110, - 25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937, - 25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885, - 26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808, - 32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320, - 33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918, - 34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125, - 37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973, - 38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485, - 38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497, - 39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514, - 41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512, - 42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680, + static const uint16_t kgrid_1bit_512[NGRID_IQ1S] = { + 5, 32, 40, 89, 101, 128, 138, 149, 160, 162, 170, 273, 281, 294, 329, 336, + 338, 341, 344, 346, 353, 356, 389, 401, 404, 409, 421, 517, 552, 584, 586, 593, + 640, 642, 661, 672, 674, 1108, 1110, 1160, 1169, 1192, 1286, 1301, 1306, 1313, 1349, 1361, + 1365, 1369, 1381, 1429, 1441, 1449, 1536, 1561, 1620, 1622, 1669, 1704, 1706, 2048, 2080, 2082, + 2122, 2129, 2176, 2178, 2197, 2208, 2210, 2326, 2329, 2341, 2369, 2384, 2389, 2401, 2404, 2409, + 2469, 2562, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2649, 2661, 2696, 2698, 2705, 4113, 4133, + 4181, 4186, 4193, 4226, 4249, 4261, 4353, 4356, 4358, 4361, 4370, 4373, 4378, 4385, 4393, 4416, + 4421, 4433, 4437, 4441, 4448, 4450, 4453, 4484, 4489, 4502, 4516, 4625, 4645, 4692, 4694, 4705, + 4753, 4773, 5141, 5153, 5161, 5190, 5193, 5201, 5205, 5208, 5216, 5221, 5253, 5268, 5393, 5397, + 5398, 5401, 5410, 5412, 5442, 5444, 5445, 5450, 5457, 5460, 5461, 5462, 5465, 5473, 5477, 5480, + 5482, 5504, 5510, 5521, 5525, 5528, 5541, 5653, 5697, 5702, 5705, 5712, 5714, 5717, 5720, 5722, + 5732, 5734, 5737, 5781, 6146, 6152, 6181, 6186, 6213, 6228, 6230, 6233, 6241, 6289, 6309, 6314, + 6405, 6420, 6470, 6473, 6481, 6485, 6488, 6490, 6496, 6501, 6533, 6548, 6550, 6553, 6561, 6664, + 6678, 6741, 6753, 6786, 8194, 8213, 8261, 8281, 8294, 8328, 8330, 8337, 8360, 8464, 8472, 8485, + 8533, 8545, 8548, 8581, 8596, 8613, 8725, 8738, 8776, 8785, 8793, 8805, 8832, 8834, 8858, 8864, + 8866, 8872, 9226, 9236, 9253, 9301, 9321, 9381, 9477, 9505, 9542, 9545, 9553, 9556, 9557, 9562, + 9573, 9622, 9633, 9641, 9728, 9730, 9738, 9770, 9813, 10261, 10272, 10274, 10304, 10321, 10344, 10370, +10376, 10378, 10400, 10402, 10521, 10533, 10576, 10578, 10581, 10598, 10661, 10769, 10856, 10888, 10890, 10897, +16384, 16457, 16469, 16472, 16484, 16529, 16646, 16649, 16661, 16664, 16666, 16681, 16709, 16721, 16724, 16725, +16726, 16729, 16741, 16746, 16769, 16784, 16789, 16804, 16809, 16918, 16928, 16961, 17001, 17033, 17041, 17425, +17428, 17430, 17445, 17473, 17476, 17478, 17490, 17493, 17510, 17513, 17541, 17556, 17558, 17573, 17665, 17668, +17680, 17682, 17685, 17689, 17728, 17730, 17733, 17736, 17738, 17745, 17748, 17749, 17750, 17753, 17762, 17765, +17768, 17796, 17798, 17809, 17813, 17817, 17830, 17929, 17937, 17940, 17957, 17989, 18000, 18002, 18005, 18008, +18010, 18017, 18020, 18022, 18049, 18068, 18070, 18085, 18472, 18512, 18517, 18577, 18694, 18709, 18721, 18757, +18772, 18773, 18778, 18784, 18789, 18817, 18825, 18837, 18849, 18949, 18966, 19017, 19029, 20501, 20513, 20521, +20545, 20550, 20564, 20565, 20569, 20581, 20629, 20741, 20753, 20757, 20758, 20773, 20776, 20805, 20808, 20816, +20817, 20818, 20820, 20821, 20822, 20825, 20833, 20836, 20837, 20838, 20841, 20870, 20881, 20884, 20885, 20889, +20901, 21001, 21012, 21060, 21062, 21073, 21077, 21080, 21082, 21141, 21509, 21520, 21522, 21525, 21528, 21530, +21540, 21542, 21568, 21570, 21573, 21576, 21578, 21585, 21588, 21589, 21590, 21593, 21602, 21605, 21608, 21610, +21636, 21638, 21641, 21648, 21650, 21653, 21656, 21665, 21668, 21670, 21673, 21760, 21762, 21765, 21768, 21770, +21777, 21780, 21781, 21782, 21785, 21793, 21797, 21802, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842, +21844, 21845, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21889, 21893, 21896, 21898, 21905, +21908, 21909, 21910, 21913, 21920, 21922, 21925, 21928, 21930, 22017, 22020, 22032, 22034, 22037, 22042, 22052, +22054, 22057, 22080, 22082, 22085, 22088, 22090, 22097, 22100, 22101, 22102, 22105, 22112, 22114, 22117, 22120, +22122, 22148, 22150, 22160, 22162, 22165, 22168, 22170, 22177, 22180, 22182, 22185, 22548, 22550, 22561, 22598, +22601, 22609, 22613, 22616, 22618, 22624, 22630, 22633, 22677, 22793, 22801, 22805, 22808, 22810, 22825, 22849, +22852, 22853, 22858, 22865, 22866, 22868, 22869, 22870, 22873, 22884, 22885, 22890, 22912, 22918, 22929, 22933, +22936, 22938, 22950, 22953, 23060, 23065, 23077, 23110, 23121, 23125, 23130, 23142, 23145, 23169, 23188, 23190, +23205, 24581, 24593, 24596, 24601, 24661, 24664, 24709, 24726, 24729, 24833, 24853, 24865, 24868, 24870, 24873, +24900, 24902, 24913, 24917, 24921, 24933, 24938, 24981, 24993, 24996, 25001, 25105, 25173, 25188, 25221, 25233, +25253, 25621, 25633, 25641, 25669, 25680, 25682, 25685, 25689, 25701, 25749, 25860, 25862, 25865, 25873, 25877, +25882, 25896, 25920, 25922, 25925, 25928, 25930, 25937, 25940, 25941, 25942, 25945, 25952, 25957, 25958, 25988, +25990, 25993, 26001, 26005, 26021, 26117, 26132, 26134, 26137, 26149, 26177, 26180, 26182, 26185, 26192, 26194, +26197, 26200, 26202, 26209, 26217, 26260, 26262, 26265, 26625, 26649, 26709, 26757, 26769, 26789, 26896, 26901, +26918, 26950, 26953, 26965, 26968, 26970, 26977, 27024, 27026, 27029, 27044, 27205, 27220, 27222, 27225, 27237, +27306, 32770, 32776, 32778, 32789, 32800, 32802, 32808, 32810, 32837, 32849, 32854, 32857, 32869, 32896, 32898, +32904, 32906, 32917, 32928, 32936, 32938, 33029, 33046, 33089, 33106, 33109, 33121, 33124, 33126, 33169, 33172, +33174, 33189, 33282, 33314, 33322, 33352, 33354, 33408, 33429, 33448, 33450, 33817, 33872, 33877, 33945, 33954, +34054, 34081, 34086, 34116, 34121, 34129, 34133, 34136, 34138, 34153, 34177, 34194, 34212, 34304, 34325, 34344, +34388, 34390, 34393, 34405, 34437, 34821, 34848, 34850, 34888, 34890, 34922, 34944, 34946, 34965, 34976, 35089, +35092, 35109, 35152, 35157, 35172, 35336, 35338, 35360, 35394, 35409, 35426, 35456, 35464, 35466, 35496, 35498, +36881, 36889, 36901, 36949, 36997, 37009, 37012, 37014, 37029, 37121, 37124, 37153, 37161, 37189, 37204, 37205, +37209, 37221, 37269, 37274, 37284, 37397, 37462, 37481, 37506, 37536, 37538, 37889, 37909, 37956, 37958, 37961, +37973, 37976, 37978, 37988, 37990, 37993, 38037, 38161, 38165, 38170, 38180, 38208, 38213, 38216, 38218, 38225, +38226, 38228, 38229, 38230, 38233, 38241, 38245, 38248, 38250, 38277, 38288, 38293, 38310, 38313, 38405, 38420, +38422, 38425, 38437, 38465, 38468, 38470, 38473, 38480, 38485, 38490, 38500, 38502, 38545, 38548, 38550, 38553, +38565, 38929, 38937, 38977, 38994, 38996, 39013, 39045, 39057, 39080, 39169, 39172, 39184, 39186, 39189, 39201, +39238, 39241, 39253, 39258, 39264, 39270, 39316, 39318, 39321, 39333, 39466, 39493, 39510, 39512, 39525, 39573, +39584, 40960, 40962, 40968, 40970, 40981, 40992, 41000, 41002, 41029, 41049, 41061, 41096, 41109, 41128, 41221, +41236, 41238, 41241, 41253, 41298, 41301, 41304, 41306, 41313, 41321, 41361, 41364, 41369, 41480, 41482, 41512, +41514, 41608, 41610, 41640, 41642, 42021, 42065, 42068, 42070, 42112, 42114, 42122, 42137, 42144, 42146, 42154, +42261, 42264, 42274, 42281, 42305, 42308, 42310, 42313, 42320, 42325, 42390, 42392, 42496, 42498, 42528, 42565, +42577, 42580, 42582, 42585, 42597, 42624, 42645, 43014, 43016, 43029, 43040, 43048, 43050, 43097, 43144, 43157, +43284, 43289, 43301, 43333, 43345, 43350, 43369, 43528, 43530, 43560, 43605, 43650, 43656, 43658, 43682, 43688, }; static const uint16_t kgrid_2bit_1024[1024] = { 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70, @@ -11408,12 +11440,70 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } +static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid, + const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) { + int num_neighbors = neighbours[0]; + GGML_ASSERT(num_neighbors > 0); + float best_score = FLT_MAX; + int grid_index = -1; + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float d2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = neighbours[j]; + } + } + if (grid_index < 0) { + for (int i = 0; i < ngrid; ++i) { + const int8_t * grid_i = (const int8_t *)(grid + i); + float d2 = 0; + for (int j = 0; j < 8; ++j) { + float w = weight[j]; + float q = (grid_i[j] - 3)/2; + float diff = scale*q - xval[i]; + d2 += w*diff*diff; + } + if (d2 < best_score) { + best_score = d2; + grid_index = i; + } + } + } + if (grid_index < 0) { + printf("Oops, did not find grid point\n"); + printf("Have %d neighbours\n", num_neighbors); + for (int j = 1; j <= num_neighbors; ++j) { + const int8_t * pg = (const int8_t *)(grid + neighbours[j]); + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 8; ++i) { + float q = (pg[i] - 3)/2; + float w = weight[i]; + sumqx += w*q*xval[i]; + sumq2 += w*q*q; + } + printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2); + } + } + GGML_ASSERT(grid_index >= 0); + const int8_t * pg = (const int8_t *)(grid + grid_index); + for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2; + return grid_index; +} + static int iq1_sort_helper(const void * left, const void * right) { const float * l = left; const float * r = right; return *l < *r ? -1 : *l > *r ? 1 : 0; } +#define IQ1S_BLOCK_SIZE 16 static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ1_S); @@ -11432,20 +11522,21 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy block_iq1_s * y = vy; - float scales[QK_K/8]; - float weight[8]; - int8_t L[8]; - float sumx[9]; - float sumw[9]; - float pairs[16]; + float scales[QK_K/IQ1S_BLOCK_SIZE]; + float weight[IQ1S_BLOCK_SIZE]; + int8_t L[IQ1S_BLOCK_SIZE]; + float sumx[IQ1S_BLOCK_SIZE+1]; + float sumw[IQ1S_BLOCK_SIZE+1]; + float pairs[2*IQ1S_BLOCK_SIZE]; int * idx = (int *)(pairs + 1); - uint8_t hbit[QK_K/8]; + uint16_t index[IQ1S_BLOCK_SIZE/8]; for (int ibl = 0; ibl < nbl; ++ibl) { y[ibl].d = GGML_FP32_TO_FP16(0.f); memset(y[ibl].qs, 0, QK_K/8); - memset(y[ibl].scales, 0, QK_K/16); + memset(y[ibl].qh, 0, QK_K/32); + memset(y[ibl].scales, 0, QK_K/32); float max_scale = 0; @@ -11454,15 +11545,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; float sigma2 = sumx2/QK_K; - for (int ib = 0; ib < QK_K/8; ++ib) { - const float * xb = xbl + 8*ib; - const float * qw = quant_weights + QK_K*ibl + 8*ib; - for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) { + const float * xb = xbl + IQ1S_BLOCK_SIZE*ib; + const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib; + for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); float max = fabsf(xb[0]); - for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i])); + for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i])); if (!max) { scales[ib] = 0; - memset(L, 1, 8); + memset(L, 1, IQ1S_BLOCK_SIZE); continue; } // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem. @@ -11471,14 +11562,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and // Wi = sum[weight[j], j = 0...i], and use these to quckly get get the optimum scale // for each possible and score for each split. - for (int j = 0; j < 8; ++j) { + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) { pairs[2*j] = xb[j]; idx[2*j] = j; } - qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper); + qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper); { sumx[0] = sumw[0] = 0; - for (int j = 0; j < 8; ++j) { + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) { int i = idx[2*j]; sumx[j+1] = sumx[j] + weight[i]*xb[i]; sumw[j+1] = sumw[j] + weight[i]; @@ -11486,10 +11577,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } float best_score = 0, scale = max; int besti1 = 0, besti2 = 0; - for (int i1 = 0; i1 <= 8; ++i1) { - for (int i2 = i1; i2 <= 8; ++i2) { - float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]); - float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]); + for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) { + for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) { + float sumqx = -(sumx[i1] - sumx[0]) + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2]); + float sumq2 = (sumw[i1] - sumw[0]) + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2]); if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) { scale = sumqx/sumq2; best_score = scale*sumqx; besti1 = i1; besti2 = i2; @@ -11498,23 +11589,41 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; - for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2; + for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2; if (scale < 0) { - for (int j = 0; j < 8; ++j) L[j] = 2 - L[j]; + for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j]; scale = -scale; } - // Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring - // grid point that minimizes SSD. - uint16_t u = 0; - for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j); - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS); - GGML_ASSERT(grid_index >= 0); + bool all_on_grid = true; + for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) { + uint16_t u = 0; + for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j); + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + all_on_grid = false; + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S); + GGML_ASSERT(grid_index >= 0); + } + index[k] = grid_index; } - y[ibl].qs[ib] = grid_index & 255; - hbit[ib] = grid_index >> 8; + if (!all_on_grid) { + float sumqx = 0, sumq2 = 0; + for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]); + for (int j = 0; j < 8; ++j) { + float w = weight[8*k + j]; + float q = (pg[j] - 3)/2; + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } + if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2; + } + y[ibl].qs[2*ib+0] = index[0] & 255; + y[ibl].qs[2*ib+1] = index[1] & 255; + if (ib%2 == 0) y[ibl].qh[ib/2] = (index[0] >> 8) | ((index[1] >> 8) << 2); + else y[ibl].qh[ib/2] |= ((index[0] >> 8) | ((index[1] >> 8) << 2)) << 4; GGML_ASSERT(scale >= 0); scales[ib] = scale; max_scale = MAX(max_scale, scale); @@ -11525,14 +11634,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy continue; } - float d = max_scale/15; + float d = max_scale/31; y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed. float id = 1/d; - for (int ib = 0; ib < QK_K/8; ++ib) { - int l = nearest_int(0.5f*(id*scales[ib]-1)); - l = MAX(0, MIN(7, l)); - if (hbit[ib]) l |= 8; - y[ibl].scales[ib/2] |= (l << 4*(ib%2)); + for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ib += 2) { + int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int(0.5f*(id*scales[ib+1]-1)); + l2 = MAX(0, MIN(15, l2)); + y[ibl].scales[ib/2] = l1 | (l2 << 4); } } } diff --git a/ggml-quants.h b/ggml-quants.h index 47dd52856..68902de9c 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -218,7 +218,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N typedef struct { ggml_fp16_t d; uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");