diff --git a/ggml.c b/ggml.c index 3c0e2b9aa..b4e808664 100644 --- a/ggml.c +++ b/ggml.c @@ -903,55 +903,22 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r static inline unsigned char dQuantizeNF4(float x) { - - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if(x > 0.03979014977812767f) - if(x > 0.3893125355243683f) // 1 - if(x > 0.6427869200706482f) // 11 - if(x > 0.8614784181118011f) // 111 - return 0b1111; - else - return 0b1110; - else - if(x > 0.5016634166240692f) // 110 - return 0b1101; - else - return 0b1100; - else - if(x > 0.2035212516784668f) // 10 - if(x > 0.2920137718319893f) // 101 - return 0b1011; - else - return 0b1010; - else - if(x > 0.1202552504837513f) // 100 - return 0b1001; - else - return 0b1000; - else - if(x > -0.33967943489551544f) // 0 - if(x > -0.13791173323988914f) // 01 - if(x > -0.045525018125772476f) // 011 - return 0b0111; - else - return 0b0110; - else - if(x > -0.23460740596055984f) // 010 - return 0b0101; - else - return 0b0100; - else - if(x > -0.6106329262256622f) // 00 - if(x > -0.4599952697753906f) // 001 - return 0b0011; - else - return 0b0010; - else - if(x > -0.8480964004993439f) // 000 - return 0b0001; - else - return 0b0000; + if (x > 0.8614784181118011f) return 0b1111; + if (x > 0.6427869200706482f) return 0b1110; + if (x > 0.5016634166240692f) return 0b1101; + if (x > 0.3893125355243683f) return 0b1100; + if (x > 0.2920137718319893f) return 0b1011; + if (x > 0.2035212516784668f) return 0b1010; + if (x > 0.1202552504837513f) return 0b1001; + if (x > 0.03979014977812767f) return 0b1000; + if (x > -0.045525018125772476f) return 0b0111; + if (x > -0.13791173323988914f) return 0b0110; + if (x > -0.23460740596055984f) return 0b0101; + if (x > -0.33967943489551544f) return 0b0100; + if (x > -0.4599952697753906f) return 0b0011; + if (x > -0.6106329262256622f) return 0b0010; + if (x > -0.8480964004993439f) return 0b0001; + return 0b0000; } static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { @@ -1539,56 +1506,26 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict static inline float dhDequantizeNF4(unsigned char val) { - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if((val & 0b1000) == 8) - if((val & 0b0100) == 4) // 1 - if((val & 0b0010) == 2) // 11 - if((val & 0b0001) == 1) // 111 - return 1.0f; - else - return 0.7229568362236023f; - else - if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; - else - if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; - else - return 0.07958029955625534f; - - else - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 011 - return 0.0f; - else - return -0.09105003625154495f; - else - if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; - else - return -0.28444138169288635f; - else - if((val & 0b0010) == 2) //00 - if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else - if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; - + switch (val) + { + case 0b1111: return 1.0f; + case 0b1110: return 0.7229568362236023f; + case 0b1101: return 0.5626170039176941f; + case 0b1100: return 0.44070982933044434f; + case 0b1011: return 0.33791524171829224f; + case 0b1010: return 0.24611230194568634f; + case 0b1001: return 0.16093020141124725f; + case 0b1000: return 0.07958029955625534f; + case 0b0111: return 0.0f; + case 0b0110: return -0.09105003625154495f; + case 0b0101: return -0.18477343022823334f; + case 0b0100: return -0.28444138169288635f; + case 0b0011: return -0.39491748809814453f; + case 0b0010: return -0.5250730514526367f; + case 0b0001: return -0.6961928009986877f; + case 0b0000: return -1.0f; + default: return 0.0f; // Default case, you can change the return value as needed. + } } static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) { diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a31a18827..580ee2ae4 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -11,7 +11,7 @@ const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001; -const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002; +const float MAX_QUANTIZATION_TOTAL_ERROR = 0.0022; const float MAX_DOT_PRODUCT_ERROR = 0.02; const char* RESULT_STR[] = {"ok", "FAILED"};