diff --git a/examples/common.cpp b/examples/common.cpp index 27f39ca1a..cd6300041 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -66,38 +66,33 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -std::string process_escapes(const char* input) { - std::string output; +void process_escapes(std::string& input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; - if (input != nullptr) { - std::size_t input_len = std::strlen(input); - output.reserve(input_len); - - for (std::size_t i = 0; i < input_len; ++i) { - if (input[i] == '\\' && i + 1 < input_len) { - switch (input[++i]) { - case 'n': output.push_back('\n'); break; - case 'r': output.push_back('\r'); break; - case 't': output.push_back('\t'); break; - case '\'': output.push_back('\''); break; - case '\"': output.push_back('\"'); break; - case '\\': output.push_back('\\'); break; - default: output.push_back('\\'); - output.push_back(input[i]); break; - } - } else { - output.push_back(input[i]); + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; } + } else { + input[output_idx++] = input[input_idx]; } } - return output; + input.resize(output_idx); } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; bool escape_prompt = false; - const char* prompt = nullptr; std::string arg; gpt_params default_params; @@ -121,7 +116,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - prompt = argv[i]; + params.prompt = argv[i]; } else if (arg == "-e") { escape_prompt = true; } else if (arg == "--session") { @@ -340,7 +335,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { gpt_print_usage(argc, argv, default_params); exit(1); } - params.prompt = escape_prompt ? process_escapes(prompt) : prompt; + if (escape_prompt) { + process_escapes(params.prompt); + } return true; }