diff --git a/common/common.cpp b/common/common.cpp index 03f35de16..6aac39aad 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -128,64 +128,148 @@ nlohmann::json get_json(const char* file_name) noexcept { return {}; } -bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - bool result = true; +std::vector args_parse_json_only(char* file_name) { std::vector arguments_w_json; - arguments_w_json.push_back(argv[0]); - int pos = 1; - // only the second argument to reduce reading attempts (plus drag'n'drop) - if (argc > 1) { - // console arguments should override json values, so json processing goes first - nlohmann::json file_config = get_json(argv[1]); - pos = 2; // avoid putting file name into arguments - if (!file_config.empty()) { + nlohmann::json file_config = get_json(file_name); + if (!file_config.empty()) { // ensures no unnecessary work + arguments_w_json.push_back(file_name); - for (auto& p : file_config.items()) { - // only use strings, numbers and booleans for switches - if (p.value().is_string() || p.value().is_number() || p.value().is_boolean()) { - char* key = new char[p.key().length() + 1]; - strcpy(key, p.key().c_str()); - arguments_w_json.push_back(key); + for (auto& p : file_config.items()) { + // only use strings, numbers and booleans for switches + if (p.value().is_string() || p.value().is_number() || p.value().is_boolean()) { + char* key = new char[p.key().length() + 1]; + strcpy(key, p.key().c_str()); + arguments_w_json.push_back(key); - if (!p.value().is_boolean()) { - std::string param_value; - if (p.value().is_string()) { - param_value = p.value().get(); - } else if (p.value().is_number()) { - param_value = std::to_string(p.value().get()); // nlohmann::json can't just get numbers as strings, float works fine for int values - } - - char* val = new char[param_value.length() + 1]; - strcpy(val, param_value.c_str()); - arguments_w_json.push_back(val); + if (!p.value().is_boolean()) { + std::string param_value; + if (p.value().is_string()) { + param_value = p.value().get(); + } else if (p.value().is_number()) { + param_value = std::to_string(p.value().get()); // nlohmann::json can't just get numbers as strings, float works fine for int values } + char* val = new char[param_value.length() + 1]; + strcpy(val, param_value.c_str()); + arguments_w_json.push_back(val); } } } + } + + return arguments_w_json; +} - for (int i = pos; i < argc; i++) { - arguments_w_json.push_back(argv[i]); +std::vector args_parse_json_only_string(char* file_name) { + std::vector arguments_w_json; + nlohmann::json file_config = get_json(file_name); + if (!file_config.empty()) { // ensures no unnecessary work + arguments_w_json.push_back(file_name); + + for (auto& p : file_config.items()) { + // only use strings, numbers and booleans for switches + if (p.value().is_string() || p.value().is_number() || p.value().is_boolean()) { + arguments_w_json.push_back(p.key()); + + if (!p.value().is_boolean()) { + std::string param_value; + if (p.value().is_string()) { + param_value = p.value().get(); + } else if (p.value().is_number()) { + param_value = std::to_string(p.value().get()); // nlohmann::json can't just get numbers as strings, float works fine for int values + } + arguments_w_json.push_back(param_value); + } + } } } + + return arguments_w_json; +} - int argc_json = arguments_w_json.size(); - char** argv_json = &arguments_w_json[0]; +// standalone parsing attempt +bool gpt_params_parse_json(char* file_name, gpt_params & params) { + bool result = true; + //std::vector arguments = args_parse_json_only(file_name); + std::vector arguments = args_parse_json_only_string(file_name); + + if (!arguments.empty()) { // ensures no unnecessary work + int argc_json = arguments.size(); + //char** args_json = const_cast(&arguments[0]); + char** args_json = new char*[arguments.size()]; + for(size_t i = 0; i < arguments.size(); i++) { + args_json[i] = new char[arguments[i].size() + 1]; + strcpy(args_json[i], arguments[i].c_str()); + } + + try { + if (!gpt_params_parse_ex(argc_json, args_json, params)) { + gpt_print_usage(argc_json, args_json, gpt_params()); + exit(0); + } + for (size_t i = 0; i < arguments.size(); i++) { + delete [] args_json[i]; + } + delete [] args_json; + } + catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + gpt_print_usage(argc_json, args_json, gpt_params()); + exit(1); + } + } else { + // let's also print help, pointing at a faulty file name/parameter + char** args = new char* {file_name}; + gpt_print_usage(1, args, gpt_params()); + exit(1); + } + + return result; +} + +// still deciding which one is worse, both leak +bool gpt_params_parse_json0(char* file_name, gpt_params & params) { + bool result = true; + std::vector arguments = args_parse_json_only(file_name); + + if (!arguments.empty()) { // ensures no unnecessary work + int argc_json = arguments.size(); + char** args_json = &arguments[0]; + + try { + if (!gpt_params_parse_ex(argc_json, args_json, params)) { + gpt_print_usage(argc_json, args_json, gpt_params()); + exit(0); + } + } + catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + gpt_print_usage(argc_json, args_json, gpt_params()); + exit(1); + } + } else { + // let's also print help, pointing at a faulty file name/parameter + char** args = new char* {file_name}; + gpt_print_usage(1, args, gpt_params()); + exit(1); + } + + return result; +} + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { + bool result = true; try { - if (!gpt_params_parse_ex(argc_json, argv_json, params)) { - gpt_print_usage(argc_json, argv_json, gpt_params()); + if (!gpt_params_parse_ex(argc, argv, params)) { + gpt_print_usage(argc, argv, gpt_params()); exit(0); } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); - gpt_print_usage(argc_json, argv_json, gpt_params()); + gpt_print_usage(argc, argv, gpt_params()); exit(1); } - // clearing pointers - for (auto c : arguments_w_json) { - delete [] c; - } return result; } @@ -814,7 +898,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { // End of Parse args for logging parameters #endif // LOG_DISABLE_LOGS } else { - throw std::invalid_argument("error: unknown argument: " + arg); + if (!gpt_params_parse_json0(argv[i], params)) { // attempt to read as a file + invalid_param = true; + throw std::invalid_argument("error: unknown argument: " + arg); + } } } if (invalid_param) { diff --git a/common/common.h b/common/common.h index d8c14467b..ad38142df 100644 --- a/common/common.h +++ b/common/common.h @@ -138,6 +138,14 @@ struct gpt_params { nlohmann::json get_json(const char* file_name) noexcept; +std::vector args_parse_json_only(char* file_name); + +std::vector args_parse_json_only_string(char* file_name); + +bool gpt_params_parse_json(char* file_name, gpt_params & params); + +bool gpt_params_parse_json0(char* file_name, gpt_params & params); + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);