This commit is contained in:
slaren 2023-08-17 03:22:19 +02:00
parent 94218e8ade
commit 9c3866099b

View file

@ -110,7 +110,7 @@ static std::string get_cpu_info() {
return id; return id;
} }
static std::string get_gpu_info(void) { static std::string get_gpu_info() {
std::string id; std::string id;
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
int count = ggml_cuda_get_device_count(); int count = ggml_cuda_get_device_count();
@ -589,25 +589,22 @@ const std::string test::gpu_info = get_gpu_info();
struct printer { struct printer {
FILE * fout; FILE * fout;
virtual void print_header(const cmd_params & params) { (void)params; }; virtual void print_header(const cmd_params & params) { (void) params; };
virtual void print_test(const test & t) = 0; virtual void print_test(const test & t) = 0;
virtual void print_footer() {}; virtual void print_footer() { };
}; };
// TODO: escape strings // TODO: escape strings
struct csv_printer : public printer { struct csv_printer : public printer {
virtual void print_header(const cmd_params & params) { void print_header(const cmd_params & params) override {
std::vector<std::string> fields = test::get_fields(); std::vector<std::string> fields = test::get_fields();
fprintf(fout, "%s\n", join(fields, ",").c_str()); fprintf(fout, "%s\n", join(fields, ",").c_str());
(void) params; (void) params;
} }
void print_values(const std::vector<std::string> & values) { void print_test(const test & t) override {
fprintf(fout, "%s", join(values, ",").c_str()); std::vector<std::string> values = t.get_values();
} fprintf(fout, "%s\n", join(values, ",").c_str());
virtual void print_test(const test & t) {
print_values(t.get_values());
} }
}; };
@ -621,16 +618,12 @@ struct json_printer : public printer {
} }
} }
virtual void print_header(const cmd_params & params) { void print_header(const cmd_params & params) override {
fprintf(fout, "[\n"); fprintf(fout, "[\n");
(void) params; (void) params;
} }
virtual void print_footer() { void print_test(const test & t) override {
fprintf(fout, "\n]\n");
}
virtual void print_test(const test & t) {
if (first) { if (first) {
first = false; first = false;
} else { } else {
@ -642,6 +635,10 @@ struct json_printer : public printer {
fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str()); fprintf(fout, " \"samples_ts\": [ %s ]\n", join(t.get_ts(), ", ").c_str());
fprintf(fout, " }"); fprintf(fout, " }");
} }
void print_footer() override {
fprintf(fout, "\n]\n");
}
}; };
struct markdown_printer : public printer { struct markdown_printer : public printer {
@ -661,7 +658,7 @@ struct markdown_printer : public printer {
return width; return width;
} }
virtual void print_header(const cmd_params & params) { void print_header(const cmd_params & params) override {
fields = { "model", "backend" }; fields = { "model", "backend" };
bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS"; bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS";
if (!is_cpu_backend) { if (!is_cpu_backend) {
@ -705,7 +702,7 @@ struct markdown_printer : public printer {
(void) params; (void) params;
} }
virtual void print_test(const test & t) { void print_test(const test & t) override {
std::map<std::string, std::string> vmap = t.get_map(); std::map<std::string, std::string> vmap = t.get_map();
fprintf(fout, "|"); fprintf(fout, "|");
@ -747,7 +744,7 @@ struct markdown_printer : public printer {
fprintf(fout, "\n"); fprintf(fout, "\n");
} }
virtual void print_footer() { void print_footer() override {
fprintf(fout, "\nbuild: %s (%d)\n", test::build_commit.c_str(), test::build_number); fprintf(fout, "\nbuild: %s (%d)\n", test::build_commit.c_str(), test::build_number);
} }
}; };
@ -772,7 +769,7 @@ struct sql_printer : public printer {
return "TEXT"; return "TEXT";
} }
virtual void print_header(const cmd_params & params) { void print_header(const cmd_params & params) override {
std::vector<std::string> fields = test::get_fields(); std::vector<std::string> fields = test::get_fields();
fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n"); fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n");
for (size_t i = 0; i < fields.size(); i++) { for (size_t i = 0; i < fields.size(); i++) {
@ -783,7 +780,7 @@ struct sql_printer : public printer {
(void) params; (void) params;
} }
virtual void print_test(const test & t) { void print_test(const test & t) override {
fprintf(fout, "INSERT INTO test (%s) ", join(test::get_fields(), ", ").c_str()); fprintf(fout, "INSERT INTO test (%s) ", join(test::get_fields(), ", ").c_str());
fprintf(fout, "VALUES ("); fprintf(fout, "VALUES (");
std::vector<std::string> values = t.get_values(); std::vector<std::string> values = t.get_values();
@ -812,9 +809,9 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
} }
static void llama_null_log_callback(enum llama_log_level level, const char * text, void * user_data) { static void llama_null_log_callback(enum llama_log_level level, const char * text, void * user_data) {
(void)level; (void) level;
(void)text; (void) text;
(void)user_data; (void) user_data;
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {