diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 203bfe8cc..cfee120c3 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,6 +1,7 @@ #include "ggml.h" #include "llama.h" #include "llama_internal.h" +#include "ggml_extra.h" #include #include @@ -29,7 +30,7 @@ struct quantize_stats_params { std::vector include_types; }; -const int64_t SCRATCH_ELEMENTS = 32*32; +const int64_t SCRATCH_ELEMENTS = 32*32*256; // So we use multi-threading in a meaningful way in the new quantization const size_t HISTOGRAM_BUCKETS = 150; const double HISTOGRAM_RANGE = 0.03; @@ -184,6 +185,7 @@ int main(int argc, char ** argv) { // read command line bool invalid_param = false; + bool checkNewQuantization = false; std::string arg; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -232,6 +234,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "error: %s not in list of types\n", argv[i]); invalid_param = true; } + } else if (arg == "-nq" || arg == "--new-quantization") { + checkNewQuantization = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); quantize_stats_print_usage(argc, argv); @@ -307,6 +311,9 @@ int main(int argc, char ** argv) { continue; } quantize_fns_t qfns = ggml_internal_get_quantize_fn(i); + if (i < 2 && checkNewQuantization) { + qfns.quantize_row_q = i == 0 ? kQuantizeQ4_0 : kQuantizeQ4_1; + } if (qfns.quantize_row_q && qfns.dequantize_row_q) { if (params.verbose) { printf("testing %s ...\n", type_strs[i]);