merge: try..catch

This commit is contained in:
ngxson 2024-03-01 14:50:42 +01:00
parent 3e6e3668c9
commit 2cfae6d9a8

View file

@ -122,34 +122,39 @@ int main(int argc, char ** argv) {
} }
} }
if (invalid_param) { try {
throw std::invalid_argument("error: invalid parameter for argument: " + arg); if (invalid_param) {
} else if (config_path.empty()) { throw std::invalid_argument("error: invalid parameter for argument: " + arg);
throw std::invalid_argument("error: missing config path"); } else if (config_path.empty()) {
} else if (model_paths.size() < 2) { throw std::invalid_argument("error: missing config path");
throw std::invalid_argument("error: require at least 2 models"); } else if (model_paths.size() < 2) {
} else if (output_path.empty()) { throw std::invalid_argument("error: require at least 2 models");
throw std::invalid_argument("error: missing output path"); } else if (output_path.empty()) {
throw std::invalid_argument("error: missing output path");
}
// buffers to hold allocated data
std::vector<int> buf_srcs;
std::vector<float> buf_scales;
auto layers = parse_config(config_path, model_paths.size(), buf_srcs, buf_scales);
std::vector<const char*> p_model_paths;
for (auto & m : model_paths) {
p_model_paths.push_back(m.data());
}
const struct llama_merge_config config{
p_model_paths.data(),
p_model_paths.size(),
layers.data(),
layers.size(),
output_path.data(),
};
llama_merge_models(&config);
} catch (const std::exception & ex) {
std::cerr << ex.what() << "\n\n";
usage(argv[0], 1);
} }
// buffers to hold allocated data
std::vector<int> buf_srcs;
std::vector<float> buf_scales;
auto layers = parse_config(config_path, model_paths.size(), buf_srcs, buf_scales);
std::vector<const char*> p_model_paths;
for (auto & m : model_paths) {
p_model_paths.push_back(m.data());
}
const struct llama_merge_config config{
p_model_paths.data(),
p_model_paths.size(),
layers.data(),
layers.size(),
output_path.data(),
};
llama_merge_models(&config);
return 0; return 0;
} }