From d4915074c473c5c4fa9c538292bc1e569d2fd222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5kon=20H=2E=20Hitland?= Date: Tue, 4 Apr 2023 00:33:09 +0200 Subject: [PATCH] quantize-stats: misc improvements Show RMSE instead of MSE - keeps similar range to the other metrics. Regex match on layer pattern. --- examples/quantize-stats/quantize-stats.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 6d8cf95a6..6ec7ce67d 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -55,9 +56,9 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) { fprintf(stderr, " --histogram\n"); fprintf(stderr, " print error histogram (default: false)\n"); fprintf(stderr, " -l LAYER, --include-layer LAYER\n"); - fprintf(stderr, " only test layers containing substring\n"); + fprintf(stderr, " only test layers matching pattern\n"); fprintf(stderr, " -L LAYER, --exclude-layer LAYER\n"); - fprintf(stderr, " exclude layers containing substring\n"); + fprintf(stderr, " exclude layers matching pattern\n"); fprintf(stderr, " -t TYPE, --type TYPE\n"); fprintf(stderr, " only test given type (q4_0, q4_1)\n"); fprintf(stderr, "\n"); @@ -66,12 +67,12 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) { // Check if a layer is included/excluded by command line bool layer_included(const quantize_stats_params params, const std::string & layer) { for (const auto& excluded : params.exclude_layers) { - if (layer.find(excluded) != std::string::npos) { + if (std::regex_search(layer, std::regex(excluded))) { return false; } } for (const auto& included : params.include_layers) { - if (layer.find(included) != std::string::npos) { + if (std::regex_search(layer, std::regex(included))) { return true; } } @@ -103,10 +104,10 @@ double find_quantile(const error_stats & stats, double quantile) { } void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) { - double rms = stats.total_error / (double) stats.num_samples; + double rmse = sqrt(stats.total_error / (double) stats.num_samples); double median = find_quantile(stats, .5); double pct95 = find_quantile(stats, .95); - printf("%-50s: mse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rms, stats.max_error, pct95, median); + printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median); if (print_histogram) { printf("Error distribution:\n"); for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {