diff --git a/common/common.cpp b/common/common.cpp
index c11006bcb..d9d21ce07 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1,5 +1,6 @@
 #include "common.h"
 #include "llama.h"
+#include "json-util.hpp"
 
 #include <algorithm>
 #include <cassert>
@@ -117,6 +118,7 @@ void process_escapes(std::string& input) {
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     bool result = true;
+
     try {
         if (!gpt_params_parse_ex(argc, argv, params)) {
             gpt_print_usage(argc, argv, gpt_params());
@@ -128,6 +130,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         gpt_print_usage(argc, argv, gpt_params());
         exit(1);
     }
+
     return result;
 }
 
@@ -792,7 +795,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         // End of Parse args for logging parameters
 #endif // LOG_DISABLE_LOGS
         } else {
-            throw std::invalid_argument("error: unknown argument: " + arg);
+            args_struct args_obj(argv[i]);
+
+            if (args_obj.argc > 0) {
+                int argc_json = args_obj.argc;
+                char** args_json = args_obj.argv;
+
+                return gpt_params_parse(argc_json, args_json, params);
+            } else {
+                throw std::invalid_argument("error: unknown argument: " + arg);
+            }
         }
     }
     if (invalid_param) {
diff --git a/common/json-util.hpp b/common/json-util.hpp
new file mode 100644
index 000000000..6a5e4f5ef
--- /dev/null
+++ b/common/json-util.hpp
@@ -0,0 +1,134 @@
+#pragma once
+
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <vector>
+#include "examples/server/json.hpp"
+
+// Reads `file_name` and parses it as JSON.
+// Returns the parsed document, or a null json value (empty() == true) when the
+// file cannot be opened or does not contain valid JSON. Never throws.
+inline nlohmann::json get_json(const char* file_name) noexcept {
+    try {
+        std::ifstream jstream(file_name);
+        if (!jstream.is_open()) {
+            // Callers probe every unrecognised argument as a possible config file,
+            // so a missing file is expected here; they report the unknown argument.
+            return {};
+        }
+        // Diagnostics go to stderr so that stdout carries program output only.
+        fprintf(stderr, "Opening a json file %s\n", file_name);
+        return nlohmann::json::parse(jstream);
+    }
+    catch (const std::exception& ex) {
+        fprintf(stderr, "%s\n", ex.what());
+    }
+
+    return {};
+}
+
+// Builds a C-style argc/argv pair from the top-level keys and values of a JSON
+// file. All strings live in one contiguous buffer (arg_chars); argv is only
+// valid while this object is alive and no further add() call is made.
+struct args_struct {
+    size_t argc = 0;
+    char** argv = nullptr;
+    size_t elements_count = 0; // number of arguments stored so far
+    size_t index = 0;          // offset in arg_chars where the next argument starts
+
+    std::vector<char>   arg_chars = {}; // NUL-terminated arguments, back to back
+    std::vector<size_t> arg_idxs  = {}; // start offset of each argument in arg_chars
+    std::vector<char*>  arg_ptrs  = {}; // backing storage for argv
+
+    args_struct() = default;
+
+    args_struct(char* file_name) {
+        createParams(file_name);
+    }
+
+    // argv points into arg_ptrs, so a member-wise copy would alias the source object.
+    args_struct(const args_struct&) = delete;
+    args_struct& operator=(const args_struct&) = delete;
+
+    ~args_struct() = default;
+
+    // Drops every stored argument.
+    void reset() noexcept {
+        elements_count = 0;
+        index = 0;
+
+        arg_chars.clear();
+        arg_idxs.clear();
+
+        reset_args();
+    }
+
+    // Invalidates argc/argv without touching the stored strings.
+    void reset_args() noexcept {
+        arg_ptrs.clear();
+        argc = 0;
+        argv = nullptr;
+    }
+
+    // Appends one argument. Only its offset is recorded because arg_chars may
+    // reallocate while growing; pointers are materialised later by get_args().
+    void add(const std::string& data) {
+        // resetting previous args
+        reset_args();
+
+        arg_idxs.emplace_back(index);
+        arg_chars.insert(arg_chars.end(), data.begin(), data.end());
+        arg_chars.emplace_back('\0');
+
+        index += data.size() + 1;
+        ++elements_count;
+    }
+
+    // Publishes the stored arguments through argc/argv.
+    void get_args() {
+        reset_args();
+        arg_ptrs.reserve(elements_count + 1);
+
+        for (const size_t offset : arg_idxs) {
+            arg_ptrs.emplace_back(&arg_chars[offset]);
+        }
+        // argv[argc] is a null pointer, as it is for the argv passed to main().
+        arg_ptrs.emplace_back(nullptr);
+
+        argc = elements_count;
+        argv = arg_ptrs.data();
+    }
+
+    // Loads `file_name` and turns each top-level key into a command-line switch:
+    //   "key": "text" / number -> key value
+    //   "key": true            -> key
+    //   "key": false           -> (nothing; the switch stays off)
+    // Other value types are skipped. argc stays 0 when the file is missing,
+    // invalid, empty, or not a JSON object.
+    void createParams(char* file_name) {
+        reset(); // starting over
+        nlohmann::json file_config = get_json(file_name);
+        if (!file_config.is_object() || file_config.empty()) {
+            return;
+        }
+
+        add(file_name); // occupies argv[0], the slot that normally holds the program name
+        for (auto& p : file_config.items()) {
+            const auto& value = p.value();
+            if (value.is_boolean()) {
+                if (value.get<bool>()) {
+                    add(p.key());
+                }
+            } else if (value.is_string()) {
+                add(p.key());
+                add(value.get<std::string>());
+            } else if (value.is_number()) {
+                add(p.key());
+                // dump() keeps integers integral and floats at full precision
+                add(value.dump());
+            }
+        }
+        get_args();
+    }
+};
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 93f999298..022e04241 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "grammar-parser.h"
+#include "json-util.hpp"
 
 #include "../llava/clip.h"
 
@@ -2517,9 +2518,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         }
         else
         {
-            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            server_print_usage(argv[0], default_params, default_sparams);
-            exit(1);
+            args_struct args_obj(argv[i]);
+
+            if (args_obj.argc > 0) {
+                int argc_json = args_obj.argc;
+                char** args_json = args_obj.argv;
+
+                return server_params_parse(argc_json, args_json, sparams, params, llama);
+            } else {
+                fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+                server_print_usage(argv[0], default_params, default_sparams);
+                exit(1);
+            }
         }
     }
     if (!params.kv_overrides.empty()) {