quantize-stats: misc improvements
Show RMSE instead of MSE - keeps similar range to the other metrics. Regex match on layer pattern.
This commit is contained in:
parent
a7d3c3f304
commit
d4915074c4
1 changed files with 7 additions and 6 deletions
|
@ -11,6 +11,7 @@
|
|||
#include <cstring>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
@ -55,9 +56,9 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
|
|||
fprintf(stderr, " --histogram\n");
|
||||
fprintf(stderr, " print error histogram (default: false)\n");
|
||||
fprintf(stderr, " -l LAYER, --include-layer LAYER\n");
|
||||
fprintf(stderr, " only test layers containing substring\n");
|
||||
fprintf(stderr, " only test layers matching pattern\n");
|
||||
fprintf(stderr, " -L LAYER, --exclude-layer LAYER\n");
|
||||
fprintf(stderr, " exclude layers containing substring\n");
|
||||
fprintf(stderr, " exclude layers matching pattern\n");
|
||||
fprintf(stderr, " -t TYPE, --type TYPE\n");
|
||||
fprintf(stderr, " only test given type (q4_0, q4_1)\n");
|
||||
fprintf(stderr, "\n");
|
||||
|
@ -66,12 +67,12 @@ void quantize_stats_print_usage(int /*argc*/, char ** argv) {
|
|||
// Check if a layer is included/excluded by command line
|
||||
bool layer_included(const quantize_stats_params params, const std::string & layer) {
|
||||
for (const auto& excluded : params.exclude_layers) {
|
||||
if (layer.find(excluded) != std::string::npos) {
|
||||
if (std::regex_search(layer, std::regex(excluded))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (const auto& included : params.include_layers) {
|
||||
if (layer.find(included) != std::string::npos) {
|
||||
if (std::regex_search(layer, std::regex(included))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -103,10 +104,10 @@ double find_quantile(const error_stats & stats, double quantile) {
|
|||
}
|
||||
|
||||
void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) {
|
||||
double rms = stats.total_error / (double) stats.num_samples;
|
||||
double rmse = sqrt(stats.total_error / (double) stats.num_samples);
|
||||
double median = find_quantile(stats, .5);
|
||||
double pct95 = find_quantile(stats, .95);
|
||||
printf("%-50s: mse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rms, stats.max_error, pct95, median);
|
||||
printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median);
|
||||
if (print_histogram) {
|
||||
printf("Error distribution:\n");
|
||||
for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue