diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8ea5148c9..90e74501e 100755 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -537,6 +537,25 @@ struct test { return fields; } + enum field_type {STRING, BOOL, INT, FLOAT}; + + static field_type get_field_type(const std::string & field) { + if (field == "build_number" || field == "n_batch" || field == "n_threads" || + field == "n_gpu_layers" || field == "main_gpu" || + field == "n_prompt" || field == "n_gen" || + field == "avg_ns" || field == "stddev_ns") { + return INT; + } + if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" || + field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") { + return BOOL; + } + if (field == "avg_ts" || field == "stddev_ts") { + return FLOAT; + } + return STRING; + } + std::vector get_values() const { std::string tensor_split_str; int max_nonzero = 0; @@ -594,8 +613,19 @@ struct printer { virtual void print_footer() { }; }; -// TODO: escape strings struct csv_printer : public printer { + static std::string escape_csv(const std::string & field) { + std::string escaped = "\""; + for (auto c : field) { + if (c == '"') { + escaped += "\""; + } + escaped += c; + } + escaped += "\""; + return escaped; + } + void print_header(const cmd_params & params) override { std::vector fields = test::get_fields(); fprintf(fout, "%s\n", join(fields, ",").c_str()); @@ -604,6 +634,7 @@ struct csv_printer : public printer { void print_test(const test & t) override { std::vector values = t.get_values(); + std::transform(values.begin(), values.end(), values.begin(), escape_csv); fprintf(fout, "%s\n", join(values, ",").c_str()); } }; @@ -611,10 +642,32 @@ struct csv_printer : public printer { struct json_printer : public printer { bool first = true; - void print_fields(const std::vector & fields, const std::vector & values) { - assert(fields.size() == values.size()); - for (size_t i = 0; i < fields.size(); i++) { - fprintf(fout, " \"%s\": \"%s\",\n", fields.at(i).c_str(), values.at(i).c_str()); + static std::string escape_json(const std::string & value) { + std::string escaped; + for (auto c : value) { + if (c == '"') { + escaped += "\\\""; + } else if (c == '\\') { + escaped += "\\\\"; + } else if (c <= 0x1f) { + char buf[8]; + snprintf(buf, sizeof(buf), "\\u%04x", c); + escaped += buf; + } else { + escaped += c; + } + } + return escaped; + } + + static std::string format_value(const std::string & field, const std::string & value) { + switch (test::get_field_type(field)) { + case test::STRING: + return "\"" + escape_json(value) + "\""; + case test::BOOL: + return value == "0" ? "false" : "true"; + default: + return value; } } @@ -623,6 +676,13 @@ struct json_printer : public printer { (void) params; } + void print_fields(const std::vector & fields, const std::vector & values) { + assert(fields.size() == values.size()); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), format_value(fields.at(i), values.at(i)).c_str()); + } + } + void print_test(const test & t) override { if (first) { first = false; @@ -634,6 +694,7 @@ struct json_printer : public printer { fprintf(fout, " \"samples_ns\": [ %s ],\n", join(t.samples_ns, ", ").c_str()); fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str()); fprintf(fout, " }"); + fflush(fout); } void print_footer() override { @@ -652,7 +713,8 @@ struct markdown_printer : public printer { return 15; } int width = std::max((int)field.length(), 10); - if (field == "backend") { + + if (test::get_field_type(field) == test::STRING) { return -width; } return width; @@ -696,7 +758,7 @@ struct markdown_printer : public printer { fprintf(fout, "|"); for (const auto & field: fields) { int width = get_field_width(field); - fprintf(fout, " %s%s%s |", width < 0 ? ":" : "", std::string(std::abs(width) - 1, '-').c_str(), width > 0 ? ":" : ""); + fprintf(fout, " %s%s |", std::string(std::abs(width) - 1, '-').c_str(), width > 0 ? ":" : "-"); } fprintf(fout, "\n"); (void) params; @@ -750,30 +812,26 @@ struct markdown_printer : public printer { }; struct sql_printer : public printer { - static std::string get_field_type(const std::string & field) { - if (field == "build_number") { - return "INTEGER"; + static std::string get_sql_field_type(const std::string & field) { + switch (test::get_field_type(field)) { + case test::STRING: + return "TEXT"; + case test::BOOL: + case test::INT: + return "INTEGER"; + case test::FLOAT: + return "REAL"; + default: + assert(false); + exit(1); } - if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas") { - return "INTEGER"; - } - if (field == "n_batch" || field == "n_threads" || field == "f16_kv" || field == "n_gpu_layers" || field == "main_gpu" || field == "mul_mat_q" || field == "low_vram") { - return "INTEGER"; - } - if (field == "n_prompt" || field == "n_gen") { - return "INTEGER"; - } - if (field == "avg_ns" || field == "stddev_ns" || field == "avg_ts" || field == "stddev_ts") { - return "REAL"; - } - return "TEXT"; } void print_header(const cmd_params & params) override { std::vector fields = test::get_fields(); fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n"); for (size_t i = 0; i < fields.size(); i++) { - fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_field_type(fields.at(i)).c_str(), i < fields.size() - 1 ? "," : ""); + fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(), i < fields.size() - 1 ? "," : ""); } fprintf(fout, ");\n"); fprintf(fout, "\n");