Trying blocks of 16 for IQ1_S - seems slightly better
This commit is contained in:
parent
ef3ced26a3
commit
c9e9acf2be
3 changed files with 190 additions and 79 deletions
|
@ -566,7 +566,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
|
|||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK_K/8];
|
||||
uint8_t scales[QK_K/16];
|
||||
uint8_t qh[QK_K/32];
|
||||
uint8_t scales[QK_K/32];
|
||||
} block_iq1_s;
|
||||
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
||||
|
||||
|
@ -1723,9 +1724,8 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
|||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const int i8 = 4*ib+il;
|
||||
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
|
||||
const float d = (float)x[i].d * (2*(h & 7) + 1);
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | (((x[i].qh[i8/4] >> 2*(i8%4)) & 3) << 8)));
|
||||
const float d = (float)x[i].d * (2*((x[i].scales[ib] >> 4*(il/2)) & 0xf) + 1);
|
||||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
|
||||
#else
|
||||
assert(false);
|
||||
|
|
258
ggml-quants.c
258
ggml-quants.c
|
@ -9996,7 +9996,7 @@ static inline int iq2_grid_size(enum ggml_type type) {
|
|||
GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ2_S);
|
||||
return type == GGML_TYPE_IQ2_XXS ? 256 :
|
||||
type == GGML_TYPE_IQ2_XS ? 512 :
|
||||
type == GGML_TYPE_IQ1_S ? 512 : 1024;
|
||||
type == GGML_TYPE_IQ1_S ? NGRID_IQ1S : 1024;
|
||||
}
|
||||
|
||||
static int iq2_compare_func(const void * left, const void * right) {
|
||||
|
@ -10063,39 +10063,71 @@ void iq2xs_init_impl(enum ggml_type type) {
|
|||
40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
|
||||
42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
|
||||
};
|
||||
static const uint16_t kgrid_1bit_512[512] = {
|
||||
10, 33, 41, 85, 132, 134, 160, 162, 277, 337, 340, 345, 357, 405, 516, 545,
|
||||
553, 598, 641, 650, 681, 1042, 1044, 1097, 1169, 1176, 1320, 1345, 1365, 1378, 1434, 1444,
|
||||
1545, 1617, 1642, 1685, 2053, 2080, 2089, 2133, 2176, 2182, 2208, 2214, 2306, 2384, 2393, 2440,
|
||||
2453, 2581, 2664, 2690, 2721, 4117, 4161, 4182, 4184, 4261, 4357, 4369, 4372, 4377, 4390, 4422,
|
||||
4432, 4437, 4449, 4457, 4485, 4497, 4505, 4629, 4677, 4696, 4774, 5205, 5217, 5225, 5386, 5397,
|
||||
5409, 5445, 5457, 5460, 5461, 5462, 5465, 5472, 5477, 5525, 5545, 5650, 5668, 5717, 5729, 5769,
|
||||
5777, 6212, 6234, 6244, 6293, 6424, 6482, 6485, 6502, 6505, 6529, 6538, 6565, 6656, 6682, 6788,
|
||||
6806, 6820, 8218, 8224, 8226, 8232, 8277, 8326, 8354, 8469, 8521, 8530, 8549, 8596, 8737, 8794,
|
||||
9221, 9253, 9348, 9369, 9380, 9474, 9557, 9633, 9732, 9753, 9793, 9830, 9862, 9880, 10240, 10272,
|
||||
10282, 10321, 10406, 10517, 10530, 10566, 10585, 10645, 10896, 16466, 16468, 16473, 16485, 16646, 16660, 16665,
|
||||
16725, 16793, 16806, 16914, 16969, 16977, 16996, 17028, 17057, 17408, 17416, 17434, 17493, 17512, 17578, 17685,
|
||||
17696, 17733, 17745, 17748, 17749, 17750, 17753, 17765, 17794, 17813, 17946, 17984, 18005, 18072, 18453, 18529,
|
||||
18569, 18722, 18756, 18762, 18773, 18794, 18833, 18853, 18945, 19026, 19033, 19077, 20489, 20497, 20500, 20517,
|
||||
20565, 20586, 20610, 20633, 20757, 20769, 20776, 20805, 20817, 20820, 20821, 20822, 20825, 20837, 20864, 20872,
|
||||
20885, 20896, 21002, 21029, 21077, 21146, 21510, 21525, 21573, 21585, 21588, 21589, 21590, 21593, 21605, 21653,
|
||||
21665, 21765, 21777, 21780, 21781, 21782, 21785, 21797, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
|
||||
21844, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21893, 21905, 21908, 21909, 21910, 21913,
|
||||
21925, 22024, 22037, 22085, 22097, 22100, 22101, 22102, 22105, 22117, 22165, 22545, 22566, 22568, 22594, 22608,
|
||||
22613, 22676, 22697, 22793, 22805, 22853, 22865, 22868, 22869, 22870, 22873, 22885, 22933, 22946, 23046, 23072,
|
||||
23125, 23209, 24597, 24640, 24665, 24673, 24725, 24833, 24840, 24869, 24917, 24934, 24965, 25001, 25108, 25110,
|
||||
25152, 25184, 25192, 25234, 25616, 25618, 25625, 25685, 25704, 25738, 25744, 25770, 25877, 25897, 25925, 25937,
|
||||
25940, 25941, 25942, 25945, 25957, 25986, 26005, 26186, 26197, 26276, 26632, 26634, 26725, 26757, 26770, 26885,
|
||||
26965, 26976, 26986, 27032, 27153, 27174, 27200, 27208, 27240, 27269, 27282, 27290, 32778, 32800, 32802, 32808,
|
||||
32810, 32853, 32904, 32922, 32930, 32932, 33105, 33110, 33112, 33125, 33157, 33280, 33288, 33301, 33312, 33320,
|
||||
33424, 33797, 33829, 33858, 34068, 34133, 34146, 34176, 34217, 34306, 34342, 34441, 34454, 34468, 34832, 34918,
|
||||
34965, 34984, 35094, 35137, 35161, 35208, 35232, 35332, 35338, 35368, 35429, 36932, 36934, 36953, 37009, 37125,
|
||||
37136, 37138, 37145, 37157, 37205, 37220, 37258, 37290, 37444, 37446, 37465, 37478, 37525, 37905, 37968, 37973,
|
||||
38040, 38054, 38145, 38154, 38165, 38180, 38186, 38213, 38225, 38228, 38229, 38230, 38233, 38245, 38293, 38485,
|
||||
38504, 38530, 38938, 38985, 38993, 39012, 39040, 39173, 39192, 39253, 39265, 39301, 39316, 39322, 39442, 39497,
|
||||
39504, 39590, 40970, 40984, 40992, 41002, 41045, 41120, 41128, 41237, 41289, 41297, 41317, 41364, 41366, 41514,
|
||||
41557, 41633, 41989, 42021, 42056, 42068, 42074, 42113, 42242, 42265, 42274, 42325, 42340, 42402, 42501, 42512,
|
||||
42533, 42624, 42632, 42666, 43040, 43093, 43106, 43168, 43176, 43264, 43286, 43345, 43429, 43590, 43618, 43680,
|
||||
static const uint16_t kgrid_1bit_512[NGRID_IQ1S] = {
|
||||
5, 32, 40, 89, 101, 128, 138, 149, 160, 162, 170, 273, 281, 294, 329, 336,
|
||||
338, 341, 344, 346, 353, 356, 389, 401, 404, 409, 421, 517, 552, 584, 586, 593,
|
||||
640, 642, 661, 672, 674, 1108, 1110, 1160, 1169, 1192, 1286, 1301, 1306, 1313, 1349, 1361,
|
||||
1365, 1369, 1381, 1429, 1441, 1449, 1536, 1561, 1620, 1622, 1669, 1704, 1706, 2048, 2080, 2082,
|
||||
2122, 2129, 2176, 2178, 2197, 2208, 2210, 2326, 2329, 2341, 2369, 2384, 2389, 2401, 2404, 2409,
|
||||
2469, 2562, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2649, 2661, 2696, 2698, 2705, 4113, 4133,
|
||||
4181, 4186, 4193, 4226, 4249, 4261, 4353, 4356, 4358, 4361, 4370, 4373, 4378, 4385, 4393, 4416,
|
||||
4421, 4433, 4437, 4441, 4448, 4450, 4453, 4484, 4489, 4502, 4516, 4625, 4645, 4692, 4694, 4705,
|
||||
4753, 4773, 5141, 5153, 5161, 5190, 5193, 5201, 5205, 5208, 5216, 5221, 5253, 5268, 5393, 5397,
|
||||
5398, 5401, 5410, 5412, 5442, 5444, 5445, 5450, 5457, 5460, 5461, 5462, 5465, 5473, 5477, 5480,
|
||||
5482, 5504, 5510, 5521, 5525, 5528, 5541, 5653, 5697, 5702, 5705, 5712, 5714, 5717, 5720, 5722,
|
||||
5732, 5734, 5737, 5781, 6146, 6152, 6181, 6186, 6213, 6228, 6230, 6233, 6241, 6289, 6309, 6314,
|
||||
6405, 6420, 6470, 6473, 6481, 6485, 6488, 6490, 6496, 6501, 6533, 6548, 6550, 6553, 6561, 6664,
|
||||
6678, 6741, 6753, 6786, 8194, 8213, 8261, 8281, 8294, 8328, 8330, 8337, 8360, 8464, 8472, 8485,
|
||||
8533, 8545, 8548, 8581, 8596, 8613, 8725, 8738, 8776, 8785, 8793, 8805, 8832, 8834, 8858, 8864,
|
||||
8866, 8872, 9226, 9236, 9253, 9301, 9321, 9381, 9477, 9505, 9542, 9545, 9553, 9556, 9557, 9562,
|
||||
9573, 9622, 9633, 9641, 9728, 9730, 9738, 9770, 9813, 10261, 10272, 10274, 10304, 10321, 10344, 10370,
|
||||
10376, 10378, 10400, 10402, 10521, 10533, 10576, 10578, 10581, 10598, 10661, 10769, 10856, 10888, 10890, 10897,
|
||||
16384, 16457, 16469, 16472, 16484, 16529, 16646, 16649, 16661, 16664, 16666, 16681, 16709, 16721, 16724, 16725,
|
||||
16726, 16729, 16741, 16746, 16769, 16784, 16789, 16804, 16809, 16918, 16928, 16961, 17001, 17033, 17041, 17425,
|
||||
17428, 17430, 17445, 17473, 17476, 17478, 17490, 17493, 17510, 17513, 17541, 17556, 17558, 17573, 17665, 17668,
|
||||
17680, 17682, 17685, 17689, 17728, 17730, 17733, 17736, 17738, 17745, 17748, 17749, 17750, 17753, 17762, 17765,
|
||||
17768, 17796, 17798, 17809, 17813, 17817, 17830, 17929, 17937, 17940, 17957, 17989, 18000, 18002, 18005, 18008,
|
||||
18010, 18017, 18020, 18022, 18049, 18068, 18070, 18085, 18472, 18512, 18517, 18577, 18694, 18709, 18721, 18757,
|
||||
18772, 18773, 18778, 18784, 18789, 18817, 18825, 18837, 18849, 18949, 18966, 19017, 19029, 20501, 20513, 20521,
|
||||
20545, 20550, 20564, 20565, 20569, 20581, 20629, 20741, 20753, 20757, 20758, 20773, 20776, 20805, 20808, 20816,
|
||||
20817, 20818, 20820, 20821, 20822, 20825, 20833, 20836, 20837, 20838, 20841, 20870, 20881, 20884, 20885, 20889,
|
||||
20901, 21001, 21012, 21060, 21062, 21073, 21077, 21080, 21082, 21141, 21509, 21520, 21522, 21525, 21528, 21530,
|
||||
21540, 21542, 21568, 21570, 21573, 21576, 21578, 21585, 21588, 21589, 21590, 21593, 21602, 21605, 21608, 21610,
|
||||
21636, 21638, 21641, 21648, 21650, 21653, 21656, 21665, 21668, 21670, 21673, 21760, 21762, 21765, 21768, 21770,
|
||||
21777, 21780, 21781, 21782, 21785, 21793, 21797, 21802, 21825, 21828, 21829, 21830, 21833, 21840, 21841, 21842,
|
||||
21844, 21845, 21846, 21848, 21849, 21850, 21857, 21860, 21861, 21862, 21865, 21889, 21893, 21896, 21898, 21905,
|
||||
21908, 21909, 21910, 21913, 21920, 21922, 21925, 21928, 21930, 22017, 22020, 22032, 22034, 22037, 22042, 22052,
|
||||
22054, 22057, 22080, 22082, 22085, 22088, 22090, 22097, 22100, 22101, 22102, 22105, 22112, 22114, 22117, 22120,
|
||||
22122, 22148, 22150, 22160, 22162, 22165, 22168, 22170, 22177, 22180, 22182, 22185, 22548, 22550, 22561, 22598,
|
||||
22601, 22609, 22613, 22616, 22618, 22624, 22630, 22633, 22677, 22793, 22801, 22805, 22808, 22810, 22825, 22849,
|
||||
22852, 22853, 22858, 22865, 22866, 22868, 22869, 22870, 22873, 22884, 22885, 22890, 22912, 22918, 22929, 22933,
|
||||
22936, 22938, 22950, 22953, 23060, 23065, 23077, 23110, 23121, 23125, 23130, 23142, 23145, 23169, 23188, 23190,
|
||||
23205, 24581, 24593, 24596, 24601, 24661, 24664, 24709, 24726, 24729, 24833, 24853, 24865, 24868, 24870, 24873,
|
||||
24900, 24902, 24913, 24917, 24921, 24933, 24938, 24981, 24993, 24996, 25001, 25105, 25173, 25188, 25221, 25233,
|
||||
25253, 25621, 25633, 25641, 25669, 25680, 25682, 25685, 25689, 25701, 25749, 25860, 25862, 25865, 25873, 25877,
|
||||
25882, 25896, 25920, 25922, 25925, 25928, 25930, 25937, 25940, 25941, 25942, 25945, 25952, 25957, 25958, 25988,
|
||||
25990, 25993, 26001, 26005, 26021, 26117, 26132, 26134, 26137, 26149, 26177, 26180, 26182, 26185, 26192, 26194,
|
||||
26197, 26200, 26202, 26209, 26217, 26260, 26262, 26265, 26625, 26649, 26709, 26757, 26769, 26789, 26896, 26901,
|
||||
26918, 26950, 26953, 26965, 26968, 26970, 26977, 27024, 27026, 27029, 27044, 27205, 27220, 27222, 27225, 27237,
|
||||
27306, 32770, 32776, 32778, 32789, 32800, 32802, 32808, 32810, 32837, 32849, 32854, 32857, 32869, 32896, 32898,
|
||||
32904, 32906, 32917, 32928, 32936, 32938, 33029, 33046, 33089, 33106, 33109, 33121, 33124, 33126, 33169, 33172,
|
||||
33174, 33189, 33282, 33314, 33322, 33352, 33354, 33408, 33429, 33448, 33450, 33817, 33872, 33877, 33945, 33954,
|
||||
34054, 34081, 34086, 34116, 34121, 34129, 34133, 34136, 34138, 34153, 34177, 34194, 34212, 34304, 34325, 34344,
|
||||
34388, 34390, 34393, 34405, 34437, 34821, 34848, 34850, 34888, 34890, 34922, 34944, 34946, 34965, 34976, 35089,
|
||||
35092, 35109, 35152, 35157, 35172, 35336, 35338, 35360, 35394, 35409, 35426, 35456, 35464, 35466, 35496, 35498,
|
||||
36881, 36889, 36901, 36949, 36997, 37009, 37012, 37014, 37029, 37121, 37124, 37153, 37161, 37189, 37204, 37205,
|
||||
37209, 37221, 37269, 37274, 37284, 37397, 37462, 37481, 37506, 37536, 37538, 37889, 37909, 37956, 37958, 37961,
|
||||
37973, 37976, 37978, 37988, 37990, 37993, 38037, 38161, 38165, 38170, 38180, 38208, 38213, 38216, 38218, 38225,
|
||||
38226, 38228, 38229, 38230, 38233, 38241, 38245, 38248, 38250, 38277, 38288, 38293, 38310, 38313, 38405, 38420,
|
||||
38422, 38425, 38437, 38465, 38468, 38470, 38473, 38480, 38485, 38490, 38500, 38502, 38545, 38548, 38550, 38553,
|
||||
38565, 38929, 38937, 38977, 38994, 38996, 39013, 39045, 39057, 39080, 39169, 39172, 39184, 39186, 39189, 39201,
|
||||
39238, 39241, 39253, 39258, 39264, 39270, 39316, 39318, 39321, 39333, 39466, 39493, 39510, 39512, 39525, 39573,
|
||||
39584, 40960, 40962, 40968, 40970, 40981, 40992, 41000, 41002, 41029, 41049, 41061, 41096, 41109, 41128, 41221,
|
||||
41236, 41238, 41241, 41253, 41298, 41301, 41304, 41306, 41313, 41321, 41361, 41364, 41369, 41480, 41482, 41512,
|
||||
41514, 41608, 41610, 41640, 41642, 42021, 42065, 42068, 42070, 42112, 42114, 42122, 42137, 42144, 42146, 42154,
|
||||
42261, 42264, 42274, 42281, 42305, 42308, 42310, 42313, 42320, 42325, 42390, 42392, 42496, 42498, 42528, 42565,
|
||||
42577, 42580, 42582, 42585, 42597, 42624, 42645, 43014, 43016, 43029, 43040, 43048, 43050, 43097, 43144, 43157,
|
||||
43284, 43289, 43301, 43333, 43345, 43350, 43369, 43528, 43530, 43560, 43605, 43650, 43656, 43658, 43682, 43688,
|
||||
};
|
||||
static const uint16_t kgrid_2bit_1024[1024] = {
|
||||
0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
|
||||
|
@ -11408,12 +11440,70 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
|
|||
return grid_index;
|
||||
}
|
||||
|
||||
// Selects the grid point that minimizes the weighted squared error
//     sum_i weight[i] * (scale*q_i - xval[i])^2
// for a fixed `scale`, where q_i = (g_i - 3)/2 (integer division) and g_i is the
// i-th byte of a grid entry.
//
//   neighbours : neighbours[0] is the number of candidates; neighbours[1..count] are
//                the candidate indices into `grid`
//   grid       : table of grid points; each 64-bit entry is read as 8 int8_t values
//   xval       : the 8 values being approximated
//   weight     : the 8 per-value weights
//   scale      : scale applied to the grid values; taken by value and not modified
//   L          : output, 8 entries; L[i] is set to (g_i - 1)/2 for the chosen grid point
//   ngrid      : number of entries in `grid`; only used by the exhaustive fallback
//
// Returns the index of the chosen grid point. Asserts if there are no candidates or if
// no grid point could be selected.
static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) {
    int num_neighbors = neighbours[0];
    GGML_ASSERT(num_neighbors > 0);
    float best_score = FLT_MAX;
    int grid_index = -1;
    // First pass: only the listed neighbours are considered.
    for (int j = 1; j <= num_neighbors; ++j) {
        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
        float d2 = 0;
        for (int i = 0; i < 8; ++i) {
            float q = (pg[i] - 3)/2; // integer division on the grid byte, then converted to float
            float w = weight[i];
            float diff = scale*q - xval[i];
            d2 += w*diff*diff;
        }
        if (d2 < best_score) {
            best_score = d2;
            grid_index = neighbours[j];
        }
    }
    if (grid_index < 0) {
        // No neighbour scored below FLT_MAX (e.g. every score was NaN or infinite):
        // fall back to an exhaustive search over the whole grid.
        for (int i = 0; i < ngrid; ++i) {
            const int8_t * grid_i = (const int8_t *)(grid + i);
            float d2 = 0;
            for (int j = 0; j < 8; ++j) {
                float w = weight[j];
                float q = (grid_i[j] - 3)/2;
                // j is the element index; i is the grid index and must not be used to index xval
                float diff = scale*q - xval[j];
                d2 += w*diff*diff;
            }
            if (d2 < best_score) {
                best_score = d2;
                grid_index = i;
            }
        }
    }
    if (grid_index < 0) {
        // Still nothing usable: dump per-neighbour statistics before the assert below fires.
        printf("Oops, did not find grid point\n");
        printf("Have %d neighbours\n", num_neighbors);
        for (int j = 1; j <= num_neighbors; ++j) {
            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
            float sumqx = 0, sumq2 = 0;
            for (int i = 0; i < 8; ++i) {
                float q = (pg[i] - 3)/2;
                float w = weight[i];
                sumqx += w*q*xval[i];
                sumq2 += w*q*q;
            }
            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
        }
    }
    GGML_ASSERT(grid_index >= 0);
    // Write the quant levels of the chosen grid point back to the caller.
    const int8_t * pg = (const int8_t *)(grid + grid_index);
    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
    return grid_index;
}
|
||||
|
||||
// qsort-style comparator: orders two elements by the float stored at the start of each.
// Only that leading float is read. Returns -1, 0 or 1; unordered operands (NaN) yield 0.
static int iq1_sort_helper(const void * left, const void * right) {
    const float lhs = *(const float *)left;
    const float rhs = *(const float *)right;
    if (lhs < rhs) {
        return -1;
    }
    if (lhs > rhs) {
        return 1;
    }
    return 0;
}
|
||||
|
||||
#define IQ1S_BLOCK_SIZE 16
|
||||
static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
|
||||
|
||||
const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
|
||||
|
@ -11432,20 +11522,21 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
|
||||
block_iq1_s * y = vy;
|
||||
|
||||
float scales[QK_K/8];
|
||||
float weight[8];
|
||||
int8_t L[8];
|
||||
float sumx[9];
|
||||
float sumw[9];
|
||||
float pairs[16];
|
||||
float scales[QK_K/IQ1S_BLOCK_SIZE];
|
||||
float weight[IQ1S_BLOCK_SIZE];
|
||||
int8_t L[IQ1S_BLOCK_SIZE];
|
||||
float sumx[IQ1S_BLOCK_SIZE+1];
|
||||
float sumw[IQ1S_BLOCK_SIZE+1];
|
||||
float pairs[2*IQ1S_BLOCK_SIZE];
|
||||
int * idx = (int *)(pairs + 1);
|
||||
uint8_t hbit[QK_K/8];
|
||||
uint16_t index[IQ1S_BLOCK_SIZE/8];
|
||||
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) {
|
||||
|
||||
y[ibl].d = GGML_FP32_TO_FP16(0.f);
|
||||
memset(y[ibl].qs, 0, QK_K/8);
|
||||
memset(y[ibl].scales, 0, QK_K/16);
|
||||
memset(y[ibl].qh, 0, QK_K/32);
|
||||
memset(y[ibl].scales, 0, QK_K/32);
|
||||
|
||||
float max_scale = 0;
|
||||
|
||||
|
@ -11454,15 +11545,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
|
||||
float sigma2 = sumx2/QK_K;
|
||||
|
||||
for (int ib = 0; ib < QK_K/8; ++ib) {
|
||||
const float * xb = xbl + 8*ib;
|
||||
const float * qw = quant_weights + QK_K*ibl + 8*ib;
|
||||
for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
||||
for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
|
||||
const float * xb = xbl + IQ1S_BLOCK_SIZE*ib;
|
||||
const float * qw = quant_weights + QK_K*ibl + IQ1S_BLOCK_SIZE*ib;
|
||||
for (int i = 0; i < IQ1S_BLOCK_SIZE; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
|
||||
float max = fabsf(xb[0]);
|
||||
for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
|
||||
for (int i = 1; i < IQ1S_BLOCK_SIZE; ++i) max = MAX(max, fabsf(xb[i]));
|
||||
if (!max) {
|
||||
scales[ib] = 0;
|
||||
memset(L, 1, 8);
|
||||
memset(L, 1, IQ1S_BLOCK_SIZE);
|
||||
continue;
|
||||
}
|
||||
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
|
||||
|
@ -11471,14 +11562,14 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
// in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
|
||||
// Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
|
||||
// for each possible and score for each split.
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
|
||||
pairs[2*j] = xb[j];
|
||||
idx[2*j] = j;
|
||||
}
|
||||
qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
|
||||
qsort(pairs, IQ1S_BLOCK_SIZE, 2*sizeof(float), iq1_sort_helper);
|
||||
{
|
||||
sumx[0] = sumw[0] = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) {
|
||||
int i = idx[2*j];
|
||||
sumx[j+1] = sumx[j] + weight[i]*xb[i];
|
||||
sumw[j+1] = sumw[j] + weight[i];
|
||||
|
@ -11486,10 +11577,10 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
}
|
||||
float best_score = 0, scale = max;
|
||||
int besti1 = 0, besti2 = 0;
|
||||
for (int i1 = 0; i1 <= 8; ++i1) {
|
||||
for (int i2 = i1; i2 <= 8; ++i2) {
|
||||
float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
|
||||
float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
|
||||
for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
|
||||
for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
|
||||
float sumqx = -(sumx[i1] - sumx[0]) + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2]);
|
||||
float sumq2 = (sumw[i1] - sumw[0]) + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2]);
|
||||
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
||||
scale = sumqx/sumq2; best_score = scale*sumqx;
|
||||
besti1 = i1; besti2 = i2;
|
||||
|
@ -11498,23 +11589,41 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
}
|
||||
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
|
||||
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
|
||||
for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
|
||||
for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
|
||||
if (scale < 0) {
|
||||
for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
|
||||
for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
|
||||
scale = -scale;
|
||||
}
|
||||
// Now we check if the solution found above corresponds to a grid point and, if not, use a neighbouring
|
||||
// grid point that minimizes SSD.
|
||||
uint16_t u = 0;
|
||||
for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
|
||||
int grid_index = kmap_q2xs[u];
|
||||
if (grid_index < 0) {
|
||||
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
||||
grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
|
||||
GGML_ASSERT(grid_index >= 0);
|
||||
bool all_on_grid = true;
|
||||
for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
|
||||
uint16_t u = 0;
|
||||
for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
|
||||
int grid_index = kmap_q2xs[u];
|
||||
if (grid_index < 0) {
|
||||
all_on_grid = false;
|
||||
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
||||
grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S);
|
||||
GGML_ASSERT(grid_index >= 0);
|
||||
}
|
||||
index[k] = grid_index;
|
||||
}
|
||||
y[ibl].qs[ib] = grid_index & 255;
|
||||
hbit[ib] = grid_index >> 8;
|
||||
if (!all_on_grid) {
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
|
||||
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
float w = weight[8*k + j];
|
||||
float q = (pg[j] - 3)/2;
|
||||
sumqx += w*q*xb[8*k+j];
|
||||
sumq2 += w*q*q;
|
||||
}
|
||||
}
|
||||
if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
|
||||
}
|
||||
y[ibl].qs[2*ib+0] = index[0] & 255;
|
||||
y[ibl].qs[2*ib+1] = index[1] & 255;
|
||||
if (ib%2 == 0) y[ibl].qh[ib/2] = (index[0] >> 8) | ((index[1] >> 8) << 2);
|
||||
else y[ibl].qh[ib/2] |= ((index[0] >> 8) | ((index[1] >> 8) << 2)) << 4;
|
||||
GGML_ASSERT(scale >= 0);
|
||||
scales[ib] = scale;
|
||||
max_scale = MAX(max_scale, scale);
|
||||
|
@ -11525,14 +11634,15 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|||
continue;
|
||||
}
|
||||
|
||||
float d = max_scale/15;
|
||||
float d = max_scale/31;
|
||||
y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
|
||||
float id = 1/d;
|
||||
for (int ib = 0; ib < QK_K/8; ++ib) {
|
||||
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
||||
l = MAX(0, MIN(7, l));
|
||||
if (hbit[ib]) l |= 8;
|
||||
y[ibl].scales[ib/2] |= (l << 4*(ib%2));
|
||||
for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ib += 2) {
|
||||
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
|
||||
l1 = MAX(0, MIN(15, l1));
|
||||
int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
|
||||
l2 = MAX(0, MIN(15, l2));
|
||||
y[ibl].scales[ib/2] = l1 | (l2 << 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -218,7 +218,8 @@ static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N
|
|||
typedef struct {
|
||||
ggml_fp16_t d;
|
||||
uint8_t qs[QK_K/8];
|
||||
uint8_t scales[QK_K/16];
|
||||
uint8_t qh[QK_K/32];
|
||||
uint8_t scales[QK_K/32];
|
||||
} block_iq1_s;
|
||||
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue