diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index ea06fcdbf..e46083707 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -36,6 +36,8 @@ public: void set_parameters(StatParams&& params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix() const; + bool load_imatrix(const char * file_name, bool add); + static bool load_imatrix(const char * file_name, std::unordered_map& imatrix); private: std::unordered_map m_stats; StatParams m_params; @@ -189,6 +191,57 @@ void IMatrixCollector::save_imatrix(const char * fname) const { } } +bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_map& imatrix_data) { + std::ifstream in(imatrix_file, std::ios::binary); + if (!in) { + printf("%s: failed to open %s\n",__func__,imatrix_file); + return false; + } + int n_entries; + in.read((char*)&n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + printf("%s: no data in file %s\n", __func__, imatrix_file); + return false; + } + for (int i = 0; i < n_entries; ++i) { + int len; in.read((char *)&len, sizeof(len)); + std::vector name_as_vec(len+1); + in.read((char *)name_as_vec.data(), len); + if (in.fail()) { + printf("%s: failed reading name for entry %d from %s\n",__func__,i+1,imatrix_file); + return false; + } + name_as_vec[len] = 0; + std::string name{name_as_vec.data()}; + auto& e = imatrix_data[std::move(name)]; + int ncall; + in.read((char*)&ncall, sizeof(ncall)); + int nval; + in.read((char *)&nval, sizeof(nval)); + if (in.fail() || nval < 1) { + printf("%s: failed reading number of values for entry %d\n",__func__,i); + imatrix_data = {}; + return false; + } + e.values.resize(nval); + in.read((char*)e.values.data(), nval*sizeof(float)); + if (in.fail()) { + printf("%s: failed reading data for entry %d\n",__func__,i); + imatrix_data = {}; + return false; + } + e.ncall = ncall; + } + return true; +} + +bool IMatrixCollector::load_imatrix(const char * file_name, bool add) { + if (!add) { + m_stats.clear(); + } + return load_imatrix(file_name, m_stats); +} + static IMatrixCollector g_collector; static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { @@ -402,6 +455,8 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool int main(int argc, char ** argv) { StatParams sparams; + std::string prev_result_file; + std::string combine_files; bool compute_ppl = true; std::vector args; args.push_back(argv[0]); @@ -423,6 +478,10 @@ int main(int argc, char ** argv) { compute_ppl = false; } else if (arg == "--keep-imatrix") { sparams.keep_every = std::stoi(argv[++iarg]); + } else if (arg == "--continue-from") { + prev_result_file = argv[++iarg]; + } else if (arg == "--combine") { + combine_files = argv[++iarg]; } else { args.push_back(argv[iarg]); } @@ -436,14 +495,50 @@ int main(int argc, char ** argv) { } } + g_collector.set_parameters(std::move(sparams)); + + if (!combine_files.empty()) { + std::vector files; + size_t pos = 0; + while (true) { + auto new_pos = combine_files.find(',', pos); + if (new_pos != std::string::npos) { + files.emplace_back(combine_files.substr(pos, new_pos - pos)); + pos = new_pos + 1; + } else { + files.emplace_back(combine_files.substr(pos)); + break; + } + } + if (files.size() < 2) { + fprintf(stderr, "You must provide at least two comma separated files to use --combine\n"); + return 1; + } + printf("Combining the following %d files\n", int(files.size())); + for (auto& file : files) { + printf(" %s\n", file.c_str()); + if (!g_collector.load_imatrix(file.c_str(), true)) { + fprintf(stderr, "Failed to load %s\n", file.c_str()); + return 1; + } + } + g_collector.save_imatrix(); + return 0; + } + + if (!prev_result_file.empty()) { + if (!g_collector.load_imatrix(prev_result_file.c_str(), false)) { + fprintf(stderr, "=============== Failed to load %s\n", prev_result_file.c_str()); + return 1; + } + } + gpt_params params; params.n_batch = 512; if (!gpt_params_parse(args.size(), args.data(), params)) { return 1; } - g_collector.set_parameters(std::move(sparams)); - params.logits_all = true; params.n_batch = std::min(params.n_batch, params.n_ctx);