cli : auto activate conversation mode if chat template is detected
parent 1244cdcf14
commit 9d7d5f21f8

3 changed files with 47 additions and 19 deletions
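Summary: conversation mode in llama-cli changes from a plain bool (default off) to a tri-state `common_conversation_mode` that defaults to AUTO. At startup, AUTO is promoted to ENABLED when the model ships a built-in chat template or one is supplied via `--chat-template`; the new `-no-cnv` flag forces the mode off, and `-cnv` still forces it on. When conversation mode is active with chat templating enabled and no prompt is given, a default system message ("You are a helpful assistant") is used. A small sketch of the resolution logic follows the common/common.h diff below.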
|
common/arg.cpp
@@ -768,15 +768,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-cnv", "--conversation"},
-        string_format(
-            "run in conversation mode:\n"
-            "- does not print special tokens and suffix/prefix\n"
-            "- interactive mode is also enabled\n"
-            "(default: %s)",
-            params.conversation ? "true" : "false"
-        ),
+        "run in conversation mode:\n"
+        "- does not print special tokens and suffix/prefix\n"
+        "- interactive mode is also enabled\n"
+        "(default: auto enabled if chat template is available)",
         [](common_params & params) {
-            params.conversation = true;
+            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"-no-cnv", "--no-conversation"},
+        "force disable conversation mode (default: false)",
+        [](common_params & params) {
+            params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
common/common.h
@@ -103,6 +103,12 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
+enum common_conversation_mode {
+    COMMON_CONVERSATION_MODE_DISABLED = 0,
+    COMMON_CONVERSATION_MODE_ENABLED  = 1,
+    COMMON_CONVERSATION_MODE_AUTO     = 2,
+};
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
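The enumerator values matter: DISABLED is 0, so the enum is falsy and truthy in exactly the places the old `bool conversation` was, and the `if (params.conversation_mode)` checks later in main.cpp keep their old meaning once AUTO has been resolved. Below is a minimal, self-contained sketch of the resolution step; the `resolve_mode` helper is hypothetical and stands in for the inline check this commit adds to main.cpp:

    #include <cstdio>

    enum common_conversation_mode {
        COMMON_CONVERSATION_MODE_DISABLED = 0, // falsy, like the old `false`
        COMMON_CONVERSATION_MODE_ENABLED  = 1, // truthy, like the old `true`
        COMMON_CONVERSATION_MODE_AUTO     = 2, // new default, resolved at startup
    };

    // Hypothetical helper (the commit inlines this logic in main.cpp):
    // AUTO is promoted to ENABLED only when a chat template is available;
    // an explicit -cnv or -no-cnv choice passes through unchanged.
    static common_conversation_mode resolve_mode(common_conversation_mode mode,
                                                 bool has_chat_template) {
        if (mode == COMMON_CONVERSATION_MODE_AUTO && has_chat_template) {
            return COMMON_CONVERSATION_MODE_ENABLED;
        }
        return mode;
    }

    int main() {
        // default run against a model with a chat template -> conversation mode on
        printf("%d\n", resolve_mode(COMMON_CONVERSATION_MODE_AUTO, true));     // 1
        // -no-cnv wins regardless of the template
        printf("%d\n", resolve_mode(COMMON_CONVERSATION_MODE_DISABLED, true)); // 0
        return 0;
    }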
@@ -275,7 +281,6 @@ struct common_params {
     bool special           = false; // enable special token output
     bool interactive       = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
-    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
@@ -301,6 +306,8 @@ struct common_params {
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
+    common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
examples/main/main.cpp
@@ -30,6 +30,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
+
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -204,8 +206,17 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }
 
+    // auto enable conversation mode if chat template is available
+    if (
+        params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO
+            && (!common_get_builtin_chat_template(model).empty() || !params.chat_template.empty())
+    ) {
+        LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
+        params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+    }
+
     // print chat template example in conversation mode
-    if (params.conversation) {
+    if (params.conversation_mode) {
         if (params.enable_chat_template) {
             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
     
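The check runs once, right after the model loads, so every later `params.conversation_mode` test sees the resolved value. AUTO is only ever promoted to ENABLED here, when either a built-in or a user-supplied chat template is present; an explicit -cnv or -no-cnv choice bypasses this branch entirely.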
@@ -252,8 +263,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+        auto prompt = (params.conversation_mode && params.enable_chat_template)
+            // format the system prompt in conversation mode (fallback to default if empty)
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
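In conversation mode, a missing -p/-f prompt no longer means an empty system message: DEFAULT_SYSTEM_MESSAGE, defined at the top of the file, is formatted through the chat template instead, and the interactive banner (the @@ -451 hunk below) tells the user how to override it.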
@@ -327,7 +340,7 @@ int main(int argc, char ** argv) {
         params.n_keep += add_bos; // always keep the BOS token
     }
 
-    if (params.conversation) {
+    if (params.conversation_mode) {
         params.interactive_first = true;
     }
 
@@ -451,7 +464,11 @@ int main(int argc, char ** argv) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
         LOG_INF( " - Press Ctrl+C to interject at any time.\n");
 #endif
-        LOG_INF( "%s\n", control_message);
+        LOG_INF( "%s", control_message);
+        if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
+            LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
+        }
+        LOG_INF("\n");
 
         is_interacting = params.interactive_first;
     }
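The remaining hunks are a mechanical rename: each `params.conversation` truthiness test becomes `params.conversation_mode`, which behaves identically because DISABLED is 0 (see the note after the common/common.h diff).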
@@ -763,7 +780,7 @@ int main(int argc, char ** argv) {
             }
 
             // if current token is not EOG, we add it to current assistant message
-            if (params.conversation) {
+            if (params.conversation_mode) {
                 const auto id = common_sampler_last(smpl);
                 assistant_ss << common_token_to_piece(ctx, id, false);
             }
@@ -771,7 +788,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG_DBG("waiting for user input\n");
 
-                if (params.conversation) {
+                if (params.conversation_mode) {
                     LOG("\n> ");
                 }
 
@@ -781,7 +798,7 @@ int main(int argc, char ** argv) {
             }
 
             std::string buffer;
-            if (!params.input_prefix.empty() && !params.conversation) {
+            if (!params.input_prefix.empty() && !params.conversation_mode) {
                 LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                 LOG("%s", params.input_prefix.c_str());
             }
@@ -805,7 +822,7 @@ int main(int argc, char ** argv) {
                 // Entering a empty line lets the user pass control back
                 if (buffer.length() > 1) {
                     // append input suffix if any
-                    if (!params.input_suffix.empty() && !params.conversation) {
+                    if (!params.input_suffix.empty() && !params.conversation_mode) {
                         LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                         LOG("%s", params.input_suffix.c_str());
                     }
@@ -818,7 +835,7 @@ int main(int argc, char ** argv) {
                 string_process_escapes(buffer);
             }
 
-            bool format_chat = params.conversation && params.enable_chat_template;
+            bool format_chat = params.conversation_mode && params.enable_chat_template;
             std::string user_inp = format_chat
                 ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                 : std::move(buffer);