From 0773028d52679f67414f5cfa08366bc86ef17fc4 Mon Sep 17 00:00:00 2001 From: Liu Ming Date: Mon, 29 May 2023 14:07:13 +0800 Subject: [PATCH] 1) make gpt_params_parse can jump over some predefined unknown args so we can reuse the gpt_params_parse function 2) fixed the grpc server error --- examples/common.cpp | 754 +++++++++++++++++++-------- examples/common.h | 1 + examples/grpc-server/grpc-server.cpp | 103 ++-- 3 files changed, 569 insertions(+), 289 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 1308f8410..24dcb30b7 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -27,59 +27,85 @@ #include #endif -int32_t get_num_physical_cores() { +int32_t get_num_physical_cores() +{ #ifdef __linux__ // enumerate the set of thread siblings, num entries is num cores std::unordered_set siblings; - for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { - std::ifstream thread_siblings("/sys/devices/system/cpu" - + std::to_string(cpu) + "/topology/thread_siblings"); - if (!thread_siblings.is_open()) { + for (uint32_t cpu = 0; cpu < UINT32_MAX; ++cpu) + { + std::ifstream thread_siblings("/sys/devices/system/cpu" + std::to_string(cpu) + "/topology/thread_siblings"); + if (!thread_siblings.is_open()) + { break; // no more cpus } std::string line; - if (std::getline(thread_siblings, line)) { + if (std::getline(thread_siblings, line)) + { siblings.insert(line); } } - if (siblings.size() > 0) { + if (siblings.size() > 0) + { return static_cast(siblings.size()); } #elif defined(__APPLE__) && defined(__MACH__) int32_t num_physical_cores; size_t len = sizeof(num_physical_cores); int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); - if (result == 0) { + if (result == 0) + { return num_physical_cores; } result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); - if (result == 0) { + if (result == 0) + { return num_physical_cores; } #elif defined(_WIN32) - //TODO: Implement + // TODO: Implement #endif unsigned int n_threads = std::thread::hardware_concurrency(); return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -void process_escapes(std::string& input) { +void process_escapes(std::string &input) +{ std::size_t input_len = input.length(); std::size_t output_idx = 0; - for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { - if (input[input_idx] == '\\' && input_idx + 1 < input_len) { - switch (input[++input_idx]) { - case 'n': input[output_idx++] = '\n'; break; - case 'r': input[output_idx++] = '\r'; break; - case 't': input[output_idx++] = '\t'; break; - case '\'': input[output_idx++] = '\''; break; - case '\"': input[output_idx++] = '\"'; break; - case '\\': input[output_idx++] = '\\'; break; - default: input[output_idx++] = '\\'; - input[output_idx++] = input[input_idx]; break; + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) + { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) + { + switch (input[++input_idx]) + { + case 'n': + input[output_idx++] = '\n'; + break; + case 'r': + input[output_idx++] = '\r'; + break; + case 't': + input[output_idx++] = '\t'; + break; + case '\'': + input[output_idx++] = '\''; + break; + case '\"': + input[output_idx++] = '\"'; + break; + case '\\': + input[output_idx++] = '\\'; + break; + default: + input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; + break; } - } else { + } + else + { input[output_idx++] = input[input_idx]; } } @@ -87,223 +113,360 @@ void process_escapes(std::string& input) { input.resize(output_idx); } -bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { +bool gpt_params_parse(int argc, char **argv, gpt_params ¶ms) +{ + return gpt_params_parse_with_extra_check(argc, argv, params, nullptr); +} + +bool check_is_extra_args(std::string arg, std::vector *extra_args) +{ + if (extra_args != nullptr) + { + for (auto &extra_arg : *extra_args) + { + if (extra_arg == arg) + { + return true; + } + } + } + return false; +} + +bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params ¶ms, std::vector *extra_args) +{ bool invalid_param = false; bool escape_prompt = false; std::string arg; gpt_params default_params; const std::string arg_prefix = "--"; - for (int i = 1; i < argc; i++) { + for (int i = 1; i < argc; i++) + { arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) + { std::replace(arg.begin(), arg.end(), '_', '-'); } - if (arg == "-s" || arg == "--seed") { + if (arg == "-s" || arg == "--seed") + { #if defined(GGML_USE_CUBLAS) fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n"); #endif - if (++i >= argc) { + if (++i >= argc) + { invalid_param = true; break; } params.seed = std::stoi(argv[i]); - } else if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { + } + else if (arg == "-t" || arg == "--threads") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_threads = std::stoi(argv[i]); - } else if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { + } + else if (arg == "-p" || arg == "--prompt") + { + if (++i >= argc) + { invalid_param = true; break; } params.prompt = argv[i]; - } else if (arg == "-e") { + } + else if (arg == "-e") + { escape_prompt = true; - } else if (arg == "--prompt-cache") { - if (++i >= argc) { + } + else if (arg == "--prompt-cache") + { + if (++i >= argc) + { invalid_param = true; break; } params.path_prompt_cache = argv[i]; - } else if (arg == "--prompt-cache-all") { + } + else if (arg == "--prompt-cache-all") + { params.prompt_cache_all = true; - } else if (arg == "-f" || arg == "--file") { - if (++i >= argc) { + } + else if (arg == "-f" || arg == "--file") + { + if (++i >= argc) + { invalid_param = true; break; } std::ifstream file(argv[i]); - if (!file) { + if (!file) + { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); invalid_param = true; break; } std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); - if (params.prompt.back() == '\n') { + if (params.prompt.back() == '\n') + { params.prompt.pop_back(); } - } else if (arg == "-n" || arg == "--n-predict") { - if (++i >= argc) { + } + else if (arg == "-n" || arg == "--n-predict") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_predict = std::stoi(argv[i]); - } else if (arg == "--top-k") { - if (++i >= argc) { + } + else if (arg == "--top-k") + { + if (++i >= argc) + { invalid_param = true; break; } params.top_k = std::stoi(argv[i]); - } else if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { + } + else if (arg == "-c" || arg == "--ctx-size") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_ctx = std::stoi(argv[i]); - } else if (arg == "--memory-f32") { + } + else if (arg == "--memory-f32") + { params.memory_f16 = false; - } else if (arg == "--top-p") { - if (++i >= argc) { + } + else if (arg == "--top-p") + { + if (++i >= argc) + { invalid_param = true; break; } params.top_p = std::stof(argv[i]); - } else if (arg == "--temp") { - if (++i >= argc) { + } + else if (arg == "--temp") + { + if (++i >= argc) + { invalid_param = true; break; } params.temp = std::stof(argv[i]); - } else if (arg == "--tfs") { - if (++i >= argc) { + } + else if (arg == "--tfs") + { + if (++i >= argc) + { invalid_param = true; break; } params.tfs_z = std::stof(argv[i]); - } else if (arg == "--typical") { - if (++i >= argc) { + } + else if (arg == "--typical") + { + if (++i >= argc) + { invalid_param = true; break; } params.typical_p = std::stof(argv[i]); - } else if (arg == "--repeat-last-n") { - if (++i >= argc) { + } + else if (arg == "--repeat-last-n") + { + if (++i >= argc) + { invalid_param = true; break; } params.repeat_last_n = std::stoi(argv[i]); - } else if (arg == "--repeat-penalty") { - if (++i >= argc) { + } + else if (arg == "--repeat-penalty") + { + if (++i >= argc) + { invalid_param = true; break; } params.repeat_penalty = std::stof(argv[i]); - } else if (arg == "--frequency-penalty") { - if (++i >= argc) { + } + else if (arg == "--frequency-penalty") + { + if (++i >= argc) + { invalid_param = true; break; } params.frequency_penalty = std::stof(argv[i]); - } else if (arg == "--presence-penalty") { - if (++i >= argc) { + } + else if (arg == "--presence-penalty") + { + if (++i >= argc) + { invalid_param = true; break; } params.presence_penalty = std::stof(argv[i]); - } else if (arg == "--mirostat") { - if (++i >= argc) { + } + else if (arg == "--mirostat") + { + if (++i >= argc) + { invalid_param = true; break; } params.mirostat = std::stoi(argv[i]); - } else if (arg == "--mirostat-lr") { - if (++i >= argc) { + } + else if (arg == "--mirostat-lr") + { + if (++i >= argc) + { invalid_param = true; break; } params.mirostat_eta = std::stof(argv[i]); - } else if (arg == "--mirostat-ent") { - if (++i >= argc) { + } + else if (arg == "--mirostat-ent") + { + if (++i >= argc) + { invalid_param = true; break; } params.mirostat_tau = std::stof(argv[i]); - } else if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { + } + else if (arg == "-b" || arg == "--batch-size") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_batch = std::stoi(argv[i]); params.n_batch = std::min(512, params.n_batch); - } else if (arg == "--keep") { - if (++i >= argc) { + } + else if (arg == "--keep") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_keep = std::stoi(argv[i]); - } else if (arg == "-m" || arg == "--model") { - if (++i >= argc) { + } + else if (arg == "-m" || arg == "--model") + { + if (++i >= argc) + { invalid_param = true; break; } params.model = argv[i]; - } else if (arg == "--lora") { - if (++i >= argc) { + } + else if (arg == "--lora") + { + if (++i >= argc) + { invalid_param = true; break; } params.lora_adapter = argv[i]; params.use_mmap = false; - } else if (arg == "--lora-base") { - if (++i >= argc) { + } + else if (arg == "--lora-base") + { + if (++i >= argc) + { invalid_param = true; break; } params.lora_base = argv[i]; - } else if (arg == "-i" || arg == "--interactive") { + } + else if (arg == "-i" || arg == "--interactive") + { params.interactive = true; - } else if (arg == "--embedding") { + } + else if (arg == "--embedding") + { params.embedding = true; - } else if (arg == "--interactive-first") { + } + else if (arg == "--interactive-first") + { params.interactive_first = true; - } else if (arg == "-ins" || arg == "--instruct") { + } + else if (arg == "-ins" || arg == "--instruct") + { params.instruct = true; - } else if (arg == "--multiline-input") { + } + else if (arg == "--multiline-input") + { params.multiline_input = true; - } else if (arg == "--color") { + } + else if (arg == "--color") + { params.use_color = true; - } else if (arg == "--mlock") { + } + else if (arg == "--mlock") + { params.use_mlock = true; - } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { + } + else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") + { + if (++i >= argc) + { invalid_param = true; break; } params.n_gpu_layers = std::stoi(argv[i]); - } else if (arg == "--no-mmap") { + } + else if (arg == "--no-mmap") + { params.use_mmap = false; - } else if (arg == "--mtest") { + } + else if (arg == "--mtest") + { params.mem_test = true; - } else if (arg == "--verbose-prompt") { + } + else if (arg == "--verbose-prompt") + { params.verbose_prompt = true; - } else if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { + } + else if (arg == "-r" || arg == "--reverse-prompt") + { + if (++i >= argc) + { invalid_param = true; break; } params.antiprompt.push_back(argv[i]); - } else if (arg == "--perplexity") { + } + else if (arg == "--perplexity") + { params.perplexity = true; - } else if (arg == "--ignore-eos") { + } + else if (arg == "--ignore-eos") + { params.logit_bias[llama_token_eos()] = -INFINITY; - } else if (arg == "--no-penalize-nl") { + } + else if (arg == "--no-penalize-nl") + { params.penalize_nl = false; - } else if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { + } + else if (arg == "-l" || arg == "--logit-bias") + { + if (++i >= argc) + { invalid_param = true; break; } @@ -311,59 +474,102 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { llama_token key; char sign; std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + try + { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) + { params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - } else { + } + else + { throw std::exception(); } - } catch (const std::exception &e) { + } + catch (const std::exception &e) + { invalid_param = true; break; } - } else if (arg == "-h" || arg == "--help") { + } + else if (arg == "-h" || arg == "--help") + { gpt_print_usage(argc, argv, default_params); exit(0); - } else if (arg == "--random-prompt") { + } + else if (arg == "--random-prompt") + { params.random_prompt = true; - } else if (arg == "--in-prefix") { - if (++i >= argc) { + } + else if (arg == "--in-prefix") + { + if (++i >= argc) + { invalid_param = true; break; } params.input_prefix = argv[i]; - } else if (arg == "--in-suffix") { - if (++i >= argc) { + } + else if (arg == "--in-suffix") + { + if (++i >= argc) + { invalid_param = true; break; } params.input_suffix = argv[i]; - } else { + } + else if (extra_args != nullptr) + { + if (check_is_extra_args(argv[i], extra_args) == true) + { + if (i+1 < argc) + { + std::string content = argv[i+1]; + if (content.compare(0, arg_prefix.size(), arg_prefix) != 0) + { + // is content for extra_arg + i++; + } + } + } + else + { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, default_params); + exit(1); + } + } + else if (extra_args == nullptr || check_is_extra_args(argv[i], extra_args) == false) + { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); exit(1); } } - if (invalid_param) { + if (invalid_param) + { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); exit(1); } if (params.prompt_cache_all && - (params.interactive || params.interactive_first || - params.instruct)) { + (params.interactive || params.interactive_first || + params.instruct)) + { fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); gpt_print_usage(argc, argv, default_params); exit(1); } - if (escape_prompt) { + if (escape_prompt) + { process_escapes(params.prompt); } return true; } -void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { +void gpt_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms) +{ fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); @@ -415,10 +621,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); - if (llama_mlock_supported()) { + if (llama_mlock_supported()) + { fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); } - if (llama_mmap_supported()) { + if (llama_mmap_supported()) + { fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } fprintf(stderr, " -ngl N, --n-gpu-layers N\n"); @@ -432,29 +640,43 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, "\n"); } -std::string gpt_random_prompt(std::mt19937 & rng) { +std::string gpt_random_prompt(std::mt19937 &rng) +{ const int r = rng() % 10; - switch (r) { - case 0: return "So"; - case 1: return "Once upon a time"; - case 2: return "When"; - case 3: return "The"; - case 4: return "After"; - case 5: return "If"; - case 6: return "import"; - case 7: return "He"; - case 8: return "She"; - case 9: return "They"; - default: return "To"; + switch (r) + { + case 0: + return "So"; + case 1: + return "Once upon a time"; + case 2: + return "When"; + case 3: + return "The"; + case 4: + return "After"; + case 5: + return "If"; + case 6: + return "import"; + case 7: + return "He"; + case 8: + return "She"; + case 9: + return "They"; + default: + return "To"; } return "The"; } // TODO: not great allocating this every time -std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { +std::vector llama_tokenize(struct llama_context *ctx, const std::string &text, bool add_bos) +{ // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars - std::vector res(text.size() + (int) add_bos); + std::vector res(text.size() + (int)add_bos); const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); assert(n >= 0); res.resize(n); @@ -462,31 +684,35 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } -struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { +struct llama_context *llama_init_from_gpt_params(const gpt_params ¶ms) +{ auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; + lparams.n_ctx = params.n_ctx; lparams.n_gpu_layers = params.n_gpu_layers; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; - llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); + llama_context *lctx = llama_init_from_file(params.model.c_str(), lparams); - if (lctx == NULL) { + if (lctx == NULL) + { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return NULL; } - if (!params.lora_adapter.empty()) { + if (!params.lora_adapter.empty()) + { int err = llama_apply_lora_from_file(lctx, params.lora_adapter.c_str(), params.lora_base.empty() ? NULL : params.lora_base.c_str(), params.n_threads); - if (err != 0) { + if (err != 0) + { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); return NULL; } @@ -495,27 +721,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -void console_init(console_state & con_st) { +void console_init(console_state &con_st) +{ #if defined(_WIN32) // Windows-specific console initialization DWORD dwMode = 0; con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); - if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) + { con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); - if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) + { con_st.hConsole = NULL; } } - if (con_st.hConsole) { + if (con_st.hConsole) + { // Enable ANSI colors on Windows 10+ - if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) + { SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); - if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) + { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); @@ -534,7 +766,8 @@ void console_init(console_state & con_st) { tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); con_st.tty = fopen("/dev/tty", "w+"); - if (con_st.tty != nullptr) { + if (con_st.tty != nullptr) + { con_st.out = con_st.tty; } @@ -542,12 +775,14 @@ void console_init(console_state & con_st) { #endif } -void console_cleanup(console_state & con_st) { +void console_cleanup(console_state &con_st) +{ // Reset console color console_set_color(con_st, CONSOLE_COLOR_DEFAULT); #if !defined(_WIN32) - if (con_st.tty != nullptr) { + if (con_st.tty != nullptr) + { con_st.out = stdout; fclose(con_st.tty); con_st.tty = nullptr; @@ -558,48 +793,60 @@ void console_cleanup(console_state & con_st) { } /* Keep track of current color of output, and emit ANSI code if it changes. */ -void console_set_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { +void console_set_color(console_state &con_st, console_color_t color) +{ + if (con_st.use_color && con_st.color != color) + { fflush(stdout); - switch(color) { - case CONSOLE_COLOR_DEFAULT: - fprintf(con_st.out, ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - fprintf(con_st.out, ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); - break; + switch (color) + { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; } con_st.color = color; fflush(con_st.out); } } -char32_t getchar32() { +char32_t getchar32() +{ #if defined(_WIN32) HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); wchar_t high_surrogate = 0; - while (true) { + while (true) + { INPUT_RECORD record; DWORD count; - if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) + { return WEOF; } - if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) + { wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; - if (wc == 0) { + if (wc == 0) + { continue; } - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + if ((wc >= 0xD800) && (wc <= 0xDBFF)) + { // Check if wc is a high surrogate high_surrogate = wc; continue; - } else if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate - if (high_surrogate != 0) { // Check if we have a high surrogate + } + else if ((wc >= 0xDC00) && (wc <= 0xDFFF)) + { // Check if wc is a low surrogate + if (high_surrogate != 0) + { // Check if we have a high surrogate return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; } } @@ -610,18 +857,22 @@ char32_t getchar32() { } #else wchar_t wc = getwchar(); - if (static_cast(wc) == WEOF) { + if (static_cast(wc) == WEOF) + { return WEOF; } #if WCHAR_MAX == 0xFFFF - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + if ((wc >= 0xD800) && (wc <= 0xDBFF)) + { // Check if wc is a high surrogate wchar_t low_surrogate = getwchar(); - if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) + { // Check if the next wchar is a low surrogate return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; } } - if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + if ((wc >= 0xD800) && (wc <= 0xDFFF)) + { // Invalid surrogate pair return 0xFFFD; // Return the replacement character U+FFFD } #endif @@ -630,17 +881,22 @@ char32_t getchar32() { #endif } -void pop_cursor(console_state & con_st) { +void pop_cursor(console_state &con_st) +{ #if defined(_WIN32) - if (con_st.hConsole != NULL) { + if (con_st.hConsole != NULL) + { CONSOLE_SCREEN_BUFFER_INFO bufferInfo; GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); COORD newCursorPosition = bufferInfo.dwCursorPosition; - if (newCursorPosition.X == 0) { + if (newCursorPosition.X == 0) + { newCursorPosition.X = bufferInfo.dwSize.X - 1; newCursorPosition.Y -= 1; - } else { + } + else + { newCursorPosition.X -= 1; } @@ -651,7 +907,8 @@ void pop_cursor(console_state & con_st) { putc('\b', con_st.out); } -int estimateWidth(char32_t codepoint) { +int estimateWidth(char32_t codepoint) +{ #if defined(_WIN32) return 1; #else @@ -659,10 +916,12 @@ int estimateWidth(char32_t codepoint) { #endif } -int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +int put_codepoint(console_state &con_st, const char *utf8_codepoint, size_t length, int expectedWidth) +{ #if defined(_WIN32) CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) + { // go with the default return expectedWidth; } @@ -674,20 +933,23 @@ int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t len GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); // Figure out our real position if we're in the last column - if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) + { DWORD nNumberOfChars; WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); } int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; - if (width < 0) { + if (width < 0) + { width += newBufferInfo.dwSize.X; } return width; #else // we can trust expectedWidth if we've got one - if (expectedWidth >= 0 || con_st.tty == nullptr) { + if (expectedWidth >= 0 || con_st.tty == nullptr) + { fwrite(utf8_codepoint, length, 1, con_st.out); return expectedWidth; } @@ -702,12 +964,14 @@ int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t len fputs("\033[6n", con_st.tty); // Query cursor position results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); - if (results != 4) { + if (results != 4) + { return expectedWidth; } int width = x2 - x1; - if (width < 0) { + if (width < 0) + { // Calculate the width considering text wrapping struct winsize w; ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); @@ -717,7 +981,8 @@ int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t len #endif } -void replace_last(console_state & con_st, char ch) { +void replace_last(console_state &con_st, char ch) +{ #if defined(_WIN32) pop_cursor(con_st); put_codepoint(con_st, &ch, 1, 1); @@ -726,44 +991,60 @@ void replace_last(console_state & con_st, char ch) { #endif } -void append_utf8(char32_t ch, std::string & out) { - if (ch <= 0x7F) { +void append_utf8(char32_t ch, std::string &out) +{ + if (ch <= 0x7F) + { out.push_back(static_cast(ch)); - } else if (ch <= 0x7FF) { + } + else if (ch <= 0x7FF) + { out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0xFFFF) { + } + else if (ch <= 0xFFFF) + { out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0x10FFFF) { + } + else if (ch <= 0x10FFFF) + { out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else { + } + else + { // Invalid Unicode code point } } // Helper function to remove the last UTF-8 character from a string -void pop_back_utf8_char(std::string & line) { - if (line.empty()) { +void pop_back_utf8_char(std::string &line) +{ + if (line.empty()) + { return; } size_t pos = line.length() - 1; // Find the start of the last UTF-8 character (checking up to 4 bytes back) - for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { - if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) + { + if ((line[pos] & 0xC0) != 0x80) + break; // Found the start of the character } line.erase(pos); } -bool console_readline(console_state & con_st, std::string & line) { +bool console_readline(console_state &con_st, std::string &line) +{ console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (con_st.out != stdout) { + if (con_st.out != stdout) + { fflush(stdout); } @@ -773,60 +1054,77 @@ bool console_readline(console_state & con_st, std::string & line) { bool end_of_stream = false; char32_t input_char; - while (true) { + while (true) + { fflush(con_st.out); // Ensure all output is displayed before waiting for input input_char = getchar32(); - if (input_char == '\r' || input_char == '\n') { + if (input_char == '\r' || input_char == '\n') + { break; } - if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + if (input_char == (char32_t)WEOF || input_char == 0x04 /* Ctrl+D*/) + { end_of_stream = true; break; } - if (is_special_char) { + if (is_special_char) + { console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); replace_last(con_st, line.back()); is_special_char = false; } - if (input_char == '\033') { // Escape sequence + if (input_char == '\033') + { // Escape sequence char32_t code = getchar32(); - if (code == '[' || code == 0x1B) { + if (code == '[' || code == 0x1B) + { // Discard the rest of the escape sequence - while ((code = getchar32()) != (char32_t) WEOF) { - if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + while ((code = getchar32()) != (char32_t)WEOF) + { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') + { break; } } } - } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace - if (!widths.empty()) { + } + else if (input_char == 0x08 || input_char == 0x7F) + { // Backspace + if (!widths.empty()) + { int count; - do { + do + { count = widths.back(); widths.pop_back(); // Move cursor back, print space, and move cursor back again - for (int i = 0; i < count; i++) { + for (int i = 0; i < count; i++) + { replace_last(con_st, ' '); pop_cursor(con_st); } pop_back_utf8_char(line); } while (count == 0 && !widths.empty()); } - } else { + } + else + { int offset = line.length(); append_utf8(input_char, line); int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); - if (width < 0) { + if (width < 0) + { width = 0; } widths.push_back(width); } - if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) + { console_set_color(con_st, CONSOLE_COLOR_PROMPT); replace_last(con_st, line.back()); is_special_char = true; @@ -834,28 +1132,38 @@ bool console_readline(console_state & con_st, std::string & line) { } bool has_more = con_st.multiline_input; - if (is_special_char) { + if (is_special_char) + { replace_last(con_st, ' '); pop_cursor(con_st); char last = line.back(); line.pop_back(); - if (last == '\\') { + if (last == '\\') + { line += '\n'; fputc('\n', con_st.out); has_more = !has_more; - } else { + } + else + { // llama will just eat the single space, it won't act as a space - if (line.length() == 1 && line.back() == ' ') { + if (line.length() == 1 && line.back() == ' ') + { line.clear(); pop_cursor(con_st); } has_more = false; } - } else { - if (end_of_stream) { + } + else + { + if (end_of_stream) + { has_more = false; - } else { + } + else + { line += '\n'; fputc('\n', con_st.out); } diff --git a/examples/common.h b/examples/common.h index 2b66382a6..812a2fd18 100644 --- a/examples/common.h +++ b/examples/common.h @@ -74,6 +74,7 @@ struct gpt_params { }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); +bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params ¶ms, std::vector *extra_args); void gpt_print_usage(int argc, char ** argv, const gpt_params & params); diff --git a/examples/grpc-server/grpc-server.cpp b/examples/grpc-server/grpc-server.cpp index 914a559f0..31c918b27 100644 --- a/examples/grpc-server/grpc-server.cpp +++ b/examples/grpc-server/grpc-server.cpp @@ -32,9 +32,6 @@ #include "message.grpc.pb.h" #endif -// ABSL_FLAG(uint16_t, port, 50051, "Server port for the service"); -// ABSL_FLAG(std::string, target, "localhost:50051", "Server address"); - using grpc::CallbackServerContext; using grpc::Server; using grpc::ServerAsyncWriter; @@ -149,6 +146,8 @@ public: params.antiprompt.clear(); no_show_words.clear(); num_tokens_predicted = 0; + // embd.clear(); + // n_remain = 0; // generated_text = ""; } @@ -250,6 +249,7 @@ public: { n_remain -= new_prompt_len; } + fprintf(stderr, "embd_inp size %d,%d\n", embd_inp.size(), params.n_ctx); if ((int)embd_inp.size() > params.n_ctx - 4) { return false; @@ -274,6 +274,7 @@ public: llama_token nextToken() { llama_token result = -1; + // fprintf(stderr, "embed size %d,%d,%d\n", embd.size(), embd_inp.size(), n_consumed); if (embd.size() > 0) { if (n_past + (int)embd.size() > params.n_ctx) @@ -482,6 +483,12 @@ public: } } + void printInfo() + { + fprintf(stderr, "embed size: %d\n", embd.size()); + fprintf(stderr, "embd_inp size: %d\n", embd_inp.size()); + } + private: gpt_params params; llama_context *ctx; @@ -523,6 +530,11 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService // fprintf(stderr, "on write done"); NextWrite(); } + void OnCancel() override + { + fprintf(stderr, "on cancel"); + delete this; + } private: CallbackServerContext *const ctx_; @@ -541,16 +553,20 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService std::lock_guard l(finish_mu_); auto result = llama_->doCompletion(); fprintf(stderr, "%s", result.c_str()); + response.set_status(llama::Status::RUNNING); response.set_output(result); StartWrite(&response); } else { { - response.set_status(llama::Status::FINISHED); std::lock_guard l(finish_mu_); - StartWriteLast(&response, grpc::WriteOptions()); + if (!finished_) + { + response.set_status(llama::Status::FINISHED); + StartWriteLast(&response, grpc::WriteOptions()); + } } // If we use WriteLast, we shouldn't wait before attempting Finish FinishOnce(Status::OK); @@ -571,13 +587,12 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService public: LlamaServiceImpl(LlamaServerContext *llama_) : llama(llama_) { - fprintf(stderr, "%s : new impl\n", __func__); } ServerWriteReactor *Answer( CallbackServerContext *context, const Job *request) { - fprintf(stderr, "%s : get answer\n", __func__); + fprintf(stderr, "%s : new answer request: %s\n", __func__, request->prompt().c_str()); llama->rewind(); // std::vector embeded = llama->complete(request->prompt()); Reactor *reactor = new Reactor(context, llama, request); @@ -625,6 +640,12 @@ void RunServer(uint16_t port, LlamaServerContext *llama) server->Wait(); } +// auto server_params_parse(server_params &sparams) +// { + +// return set_extra_params; +// } + bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params ¶ms) { gpt_params default_params; @@ -652,64 +673,6 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para } sparams.hostname = argv[i]; } - else if (arg == "-s" || arg == "--seed") - { -#if defined(GGML_USE_CUBLAS) - fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n"); -#endif - if (++i >= argc) - { - invalid_param = true; - break; - } - params.seed = std::stoi(argv[i]); - } - else if (arg == "-m" || arg == "--model") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.model = argv[i]; - } - else if (arg == "--embedding") - { - params.embedding = true; - } - else if (arg == "-h" || arg == "--help") - { - server_print_usage(argc, argv, default_params); - exit(0); - } - else if (arg == "-c" || arg == "--ctx_size") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.n_ctx = std::stoi(argv[i]); - } - else if (arg == "--memory_f32") - { - params.memory_f16 = false; - } - else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.n_gpu_layers = std::stoi(argv[i]); - } - else - { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - server_print_usage(argc, argv, default_params); - exit(1); - } } if (invalid_param) @@ -731,10 +694,18 @@ int main(int argc, char **argv) params.model = "ggml-model.bin"; params.n_ctx = 512; + // params.embedding = true; sparams.port = 8080; - if (gpt_params_parse(argc, argv, params) == false) + std::vector extra_args = {"--server", "--port"}; + + if (server_params_parse(argc, argv, sparams, params) == false) + { + return 1; + } + + if (gpt_params_parse_with_extra_check(argc, argv, params, &extra_args) == false) { return 1; }