diff --git a/common/arg.cpp b/common/arg.cpp
index 27886b84e..017501310 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -768,15 +768,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-cnv", "--conversation"},
-        string_format(
-            "run in conversation mode:\n"
-            "- does not print special tokens and suffix/prefix\n"
-            "- interactive mode is also enabled\n"
-            "(default: %s)",
-            params.conversation ? "true" : "false"
-        ),
+        "run in conversation mode:\n"
+        "- does not print special tokens and suffix/prefix\n"
+        "- interactive mode is also enabled\n"
+        "(default: auto enabled if chat template is available)",
         [](common_params & params) {
-            params.conversation = true;
+            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"-no-cnv", "--no-conversation"},
+        "force disable conversation mode (default: false)",
+        [](common_params & params) {
+            params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
diff --git a/common/common.h b/common/common.h
index d523948b0..322c03a22 100644
--- a/common/common.h
+++ b/common/common.h
@@ -103,6 +103,12 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
+enum common_conversation_mode {
+    COMMON_CONVERSATION_MODE_DISABLED = 0,
+    COMMON_CONVERSATION_MODE_ENABLED  = 1,
+    COMMON_CONVERSATION_MODE_AUTO     = 2,
+};
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -275,7 +281,6 @@ struct common_params {
     bool special           = false; // enable special token output
     bool interactive       = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
-    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
 
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
@@ -301,6 +306,8 @@ struct common_params {
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
+    common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 640b35c1d..7148e3fc3 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -30,6 +30,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
+
 static llama_context       ** g_ctx;
 static llama_model         ** g_model;
 static common_sampler      ** g_smpl;
@@ -204,8 +206,17 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }
 
+    // auto enable conversation mode if chat template is available
+    if (
+        params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO
+            && (!common_get_builtin_chat_template(model).empty() || !params.chat_template.empty())
+    ) {
+        LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
+        params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+    }
+
     // print chat template example in conversation mode
-    if (params.conversation) {
+    if (params.conversation_mode) {
         if (params.enable_chat_template) {
             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
@@ -252,8 +263,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+        auto prompt = (params.conversation_mode && params.enable_chat_template)
+            // format the system prompt in conversation mode (fallback to default if empty)
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
@@ -327,7 +340,7 @@ int main(int argc, char ** argv) {
         params.n_keep += add_bos; // always keep the BOS token
     }
 
-    if (params.conversation) {
+    if (params.conversation_mode) {
         params.interactive_first = true;
     }
 
@@ -451,7 +464,11 @@ int main(int argc, char ** argv) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
         LOG_INF(       " - Press Ctrl+C to interject at any time.\n");
 #endif
-        LOG_INF(       "%s\n", control_message);
+        LOG_INF(       "%s", control_message);
+        if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
+            LOG_INF(   " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
+        }
+        LOG_INF("\n");
 
         is_interacting = params.interactive_first;
     }
@@ -763,7 +780,7 @@ int main(int argc, char ** argv) {
             }
 
             // if current token is not EOG, we add it to current assistant message
-            if (params.conversation) {
+            if (params.conversation_mode) {
                 const auto id = common_sampler_last(smpl);
                 assistant_ss << common_token_to_piece(ctx, id, false);
             }
@@ -771,7 +788,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG_DBG("waiting for user input\n");
 
-                if (params.conversation) {
+                if (params.conversation_mode) {
                     LOG("\n> ");
                 }
 
@@ -781,7 +798,7 @@ int main(int argc, char ** argv) {
                 }
 
                 std::string buffer;
-                if (!params.input_prefix.empty() && !params.conversation) {
+                if (!params.input_prefix.empty() && !params.conversation_mode) {
                     LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                     LOG("%s", params.input_prefix.c_str());
                 }
@@ -805,7 +822,7 @@ int main(int argc, char ** argv) {
                 // Entering a empty line lets the user pass control back
                 if (buffer.length() > 1) {
                     // append input suffix if any
-                    if (!params.input_suffix.empty() && !params.conversation) {
+                    if (!params.input_suffix.empty() && !params.conversation_mode) {
                         LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                         LOG("%s", params.input_suffix.c_str());
                     }
@@ -818,7 +835,7 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }
 
-                bool format_chat = params.conversation && params.enable_chat_template;
+                bool format_chat = params.conversation_mode && params.enable_chat_template;
                 std::string user_inp = format_chat
                     ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                     : std::move(buffer);
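
Note: the following is a minimal, self-contained sketch (not part of the patch) of the tri-state conversation mode the diff introduces. The enum values and the AUTO-resolution step mirror the patch; the model and chat-template plumbing from llama.cpp is replaced with a plain bool, and `params_t` is a stand-in for `common_params`, for illustration only.

    // sketch of the tri-state conversation mode from common/common.h
    #include <cstdio>

    enum common_conversation_mode {
        COMMON_CONVERSATION_MODE_DISABLED = 0, // forced off via -no-cnv
        COMMON_CONVERSATION_MODE_ENABLED  = 1, // forced on via -cnv
        COMMON_CONVERSATION_MODE_AUTO     = 2, // default: decide from chat template availability
    };

    struct params_t { // stand-in for common_params
        common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
    };

    int main() {
        params_t params;
        const bool has_chat_template = true; // stand-in for the template checks in main.cpp

        // AUTO resolves to ENABLED only when a chat template is available,
        // mirroring the auto-enable block added to main.cpp
        if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO && has_chat_template) {
            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
        }

        // the patch rewrites `if (params.conversation)` as `if (params.conversation_mode)`;
        // DISABLED == 0, so the unscoped enum still works in boolean context
        if (params.conversation_mode) {
            printf("conversation mode active\n");
        }
        return 0;
    }

One observation on this design: an unresolved AUTO (value 2) is also truthy, so the `if (params.conversation_mode)` checks in main.cpp rely on the auto-enable block having run before any of them are reached.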