quantize-stats: misc improvements

Show RMSE instead of MSE - keeps similar range to the other metrics.
Regex match on layer pattern.
This commit is contained in:
Håkon H. Hitland 2023-04-04 00:33:09 +02:00
parent a7d3c3f304
commit d4915074c4

View file

@ -11,6 +11,7 @@
#include <cstring>
#include <map>
#include <numeric>
#include <regex>
#include <string>
#include <unordered_map>
#include <vector>
@ -55,9 +56,9 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
fprintf(stderr, " --histogram\n");
fprintf(stderr, " print error histogram (default: false)\n");
fprintf(stderr, " -l LAYER, --include-layer LAYER\n");
fprintf(stderr, " only test layers containing substring\n");
fprintf(stderr, " only test layers matching pattern\n");
fprintf(stderr, " -L LAYER, --exclude-layer LAYER\n");
fprintf(stderr, " exclude layers containing substring\n");
fprintf(stderr, " exclude layers matching pattern\n");
fprintf(stderr, " -t TYPE, --type TYPE\n");
fprintf(stderr, " only test given type (q4_0, q4_1)\n");
fprintf(stderr, "\n");
@ -66,12 +67,12 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
// Check if a layer is included/excluded by command line
bool layer_included(const quantize_stats_params params, const std::string & layer) {
for (const auto& excluded : params.exclude_layers) {
if (layer.find(excluded) != std::string::npos) {
if (std::regex_search(layer, std::regex(excluded))) {
return false;
}
}
for (const auto& included : params.include_layers) {
if (layer.find(included) != std::string::npos) {
if (std::regex_search(layer, std::regex(included))) {
return true;
}
}
@ -103,10 +104,10 @@ double find_quantile(const error_stats & stats, double quantile) {
}
void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) {
double rms = stats.total_error / (double) stats.num_samples;
double rmse = sqrt(stats.total_error / (double) stats.num_samples);
double median = find_quantile(stats, .5);
double pct95 = find_quantile(stats, .95);
printf("%-50s: mse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rms, stats.max_error, pct95, median);
printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median);
if (print_histogram) {
printf("Error distribution:\n");
for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {