improve formatting

This commit is contained in:
slaren 2023-08-17 17:28:55 +02:00
parent 9c3866099b
commit b6c81e28cd

View file

@ -537,6 +537,25 @@ struct test {
return fields; return fields;
} }
enum field_type {STRING, BOOL, INT, FLOAT};
static field_type get_field_type(const std::string & field) {
if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" ||
field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
return FLOAT;
}
return STRING;
}
std::vector<std::string> get_values() const { std::vector<std::string> get_values() const {
std::string tensor_split_str; std::string tensor_split_str;
int max_nonzero = 0; int max_nonzero = 0;
@ -594,8 +613,19 @@ struct printer {
virtual void print_footer() { }; virtual void print_footer() { };
}; };
// TODO: escape strings
struct csv_printer : public printer { struct csv_printer : public printer {
static std::string escape_csv(const std::string & field) {
std::string escaped = "\"";
for (auto c : field) {
if (c == '"') {
escaped += "\"";
}
escaped += c;
}
escaped += "\"";
return escaped;
}
void print_header(const cmd_params & params) override { void print_header(const cmd_params & params) override {
std::vector<std::string> fields = test::get_fields(); std::vector<std::string> fields = test::get_fields();
fprintf(fout, "%s\n", join(fields, ",").c_str()); fprintf(fout, "%s\n", join(fields, ",").c_str());
@ -604,6 +634,7 @@ struct csv_printer : public printer {
void print_test(const test & t) override { void print_test(const test & t) override {
std::vector<std::string> values = t.get_values(); std::vector<std::string> values = t.get_values();
std::transform(values.begin(), values.end(), values.begin(), escape_csv);
fprintf(fout, "%s\n", join(values, ",").c_str()); fprintf(fout, "%s\n", join(values, ",").c_str());
} }
}; };
@ -611,10 +642,32 @@ struct csv_printer : public printer {
struct json_printer : public printer { struct json_printer : public printer {
bool first = true; bool first = true;
void print_fields(const std::vector<std::string> & fields, const std::vector<std::string> & values) { static std::string escape_json(const std::string & value) {
assert(fields.size() == values.size()); std::string escaped;
for (size_t i = 0; i < fields.size(); i++) { for (auto c : value) {
fprintf(fout, " \"%s\": \"%s\",\n", fields.at(i).c_str(), values.at(i).c_str()); if (c == '"') {
escaped += "\\\"";
} else if (c == '\\') {
escaped += "\\\\";
} else if (c <= 0x1f) {
char buf[8];
snprintf(buf, sizeof(buf), "\\u%04x", c);
escaped += buf;
} else {
escaped += c;
}
}
return escaped;
}
static std::string format_value(const std::string & field, const std::string & value) {
switch (test::get_field_type(field)) {
case test::STRING:
return "\"" + escape_json(value) + "\"";
case test::BOOL:
return value == "0" ? "false" : "true";
default:
return value;
} }
} }
@ -623,6 +676,13 @@ struct json_printer : public printer {
(void) params; (void) params;
} }
void print_fields(const std::vector<std::string> & fields, const std::vector<std::string> & values) {
assert(fields.size() == values.size());
for (size_t i = 0; i < fields.size(); i++) {
fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), format_value(fields.at(i), values.at(i)).c_str());
}
}
void print_test(const test & t) override { void print_test(const test & t) override {
if (first) { if (first) {
first = false; first = false;
@ -634,6 +694,7 @@ struct json_printer : public printer {
fprintf(fout, " \"samples_ns\": [ %s ],\n", join(t.samples_ns, ", ").c_str()); fprintf(fout, " \"samples_ns\": [ %s ],\n", join(t.samples_ns, ", ").c_str());
fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str()); fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str());
fprintf(fout, " }"); fprintf(fout, " }");
fflush(fout);
} }
void print_footer() override { void print_footer() override {
@ -652,7 +713,8 @@ struct markdown_printer : public printer {
return 15; return 15;
} }
int width = std::max((int)field.length(), 10); int width = std::max((int)field.length(), 10);
if (field == "backend") {
if (test::get_field_type(field) == test::STRING) {
return -width; return -width;
} }
return width; return width;
@ -696,7 +758,7 @@ struct markdown_printer : public printer {
fprintf(fout, "|"); fprintf(fout, "|");
for (const auto & field: fields) { for (const auto & field: fields) {
int width = get_field_width(field); int width = get_field_width(field);
fprintf(fout, " %s%s%s |", width < 0 ? ":" : "", std::string(std::abs(width) - 1, '-').c_str(), width > 0 ? ":" : ""); fprintf(fout, " %s%s |", std::string(std::abs(width) - 1, '-').c_str(), width > 0 ? ":" : "-");
} }
fprintf(fout, "\n"); fprintf(fout, "\n");
(void) params; (void) params;
@ -750,30 +812,26 @@ struct markdown_printer : public printer {
}; };
struct sql_printer : public printer { struct sql_printer : public printer {
static std::string get_field_type(const std::string & field) { static std::string get_sql_field_type(const std::string & field) {
if (field == "build_number") { switch (test::get_field_type(field)) {
return "INTEGER"; case test::STRING:
}
if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas") {
return "INTEGER";
}
if (field == "n_batch" || field == "n_threads" || field == "f16_kv" || field == "n_gpu_layers" || field == "main_gpu" || field == "mul_mat_q" || field == "low_vram") {
return "INTEGER";
}
if (field == "n_prompt" || field == "n_gen") {
return "INTEGER";
}
if (field == "avg_ns" || field == "stddev_ns" || field == "avg_ts" || field == "stddev_ts") {
return "REAL";
}
return "TEXT"; return "TEXT";
case test::BOOL:
case test::INT:
return "INTEGER";
case test::FLOAT:
return "REAL";
default:
assert(false);
exit(1);
}
} }
void print_header(const cmd_params & params) override { void print_header(const cmd_params & params) override {
std::vector<std::string> fields = test::get_fields(); std::vector<std::string> fields = test::get_fields();
fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n"); fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n");
for (size_t i = 0; i < fields.size(); i++) { for (size_t i = 0; i < fields.size(); i++) {
fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_field_type(fields.at(i)).c_str(), i < fields.size() - 1 ? "," : ""); fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(), i < fields.size() - 1 ? "," : "");
} }
fprintf(fout, ");\n"); fprintf(fout, ");\n");
fprintf(fout, "\n"); fprintf(fout, "\n");