From 8701b58276ef7b5e41fb44caf4c2b2f5a1e72951 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Wed, 6 Dec 2023 17:13:13 +0500 Subject: [PATCH] Json integration into common (parameters parsing) --- common/common.cpp | 68 ++++++++++++++++++++++++++++++++++++++++++++--- common/common.h | 4 +++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 4e823c526..99319471f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -115,19 +115,81 @@ void process_escapes(std::string& input) { input.resize(output_idx); } +nlohmann::json get_json(std::string file_name) { + // safeguard since we expose this fucntion in header + if (file_name.find(".json") == file_name.npos) file_name += ".json"; + nlohmann::json config; + std::fstream jstream(file_name); + if (jstream.is_open()) { + try { + config = nlohmann::json::parse(jstream); + jstream.close(); + printf("Opened a json file %s\n", file_name.c_str()); + } + catch (nlohmann::json::parse_error& ex) { + jstream.close(); + fprintf(stderr, "%s\n", ex.what()); + return config; + } + } else { + printf("%s not found!\n", file_name.c_str()); + } + + return config; +} + bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool result = true; + std::vector arguments_w_json; + arguments_w_json.push_back(argv[0]); + int pos = 1; + + if (argc > 1) { + std::string json_name = argv[1]; + // console arguments should override json values, so json processing goes first + // to avoid exta work, let's expect at least an extention + if (json_name.rfind(".json") != json_name.npos) { + nlohmann::json file_config = get_json(argv[1]); + // avoid putting file name in arguments + pos = 2; + for (auto& p : file_config.items()) { + // only use strings, numbers and booleans for switches + if (p.value().is_string() || p.value().is_number() || p.value().is_boolean()) { + char* key = new char[p.key().length() + 1]; + strcpy(key, p.key().c_str()); + arguments_w_json.push_back(key); + std::string param_value; + if (!p.value().is_boolean()){ + if (p.value().is_string()) param_value = p.value().get(); + else if (p.value().is_number()) param_value = std::to_string(p.value().get()); + + char* val = new char[param_value.length() + 1]; + strcpy(val, param_value.c_str()); + arguments_w_json.push_back(val); + } + } + } + } + for (int i = pos; i < argc; i++) { + arguments_w_json.push_back(argv[i]); + } + } + + int argc_json = arguments_w_json.size(); + char** argv_json = &arguments_w_json[0]; + try { - if (!gpt_params_parse_ex(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); + if (!gpt_params_parse_ex(argc_json, argv_json, params)) { + gpt_print_usage(argc_json, argv_json, gpt_params()); exit(0); } } catch (const std::invalid_argument & ex) { fprintf(stderr, "%s\n", ex.what()); - gpt_print_usage(argc, argv, gpt_params()); + gpt_print_usage(argc_json, argv_json, gpt_params()); exit(1); } + return result; } diff --git a/common/common.h b/common/common.h index 024679380..5bb0db407 100644 --- a/common/common.h +++ b/common/common.h @@ -6,6 +6,8 @@ #include "sampling.h" +#include "examples/server/json.hpp" + #define LOG_NO_FILE_LINE_FUNCTION #include "log.h" @@ -131,6 +133,8 @@ struct gpt_params { std::string image = ""; // path to an image file }; +nlohmann::json get_json(std::string file_name); + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);