This commit is contained in:
ngxson 2024-04-24 15:31:46 +02:00
parent 588b72d950
commit f1a93548aa
2 changed files with 13 additions and 8 deletions

View file

@ -149,7 +149,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
} }
const std::string formatted_chat(buf.data(), res); const std::string formatted_chat(buf.data());
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});

View file

@ -17106,7 +17106,7 @@ LLAMA_API int32_t llama_chat_get_model_template(
return -1; return -1;
} else { } else {
snprintf(buf, length, "%s", model_template.c_str()); snprintf(buf, length, "%s", model_template.c_str());
return model_template.size() + 1; return model_template.size();
} }
} }
@ -17162,7 +17162,7 @@ LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl) {
} }
LLAMA_API int32_t llama_chat_get_prefix( LLAMA_API int32_t llama_chat_get_prefix(
const llama_chat_template tmpl, const llama_chat_template ttmpl,
const char * role, const char * role,
const char * prev_role, const char * prev_role,
char * buf, char * buf,
@ -17170,6 +17170,7 @@ LLAMA_API int32_t llama_chat_get_prefix(
std::stringstream ss; std::stringstream ss;
std::string srole(role); std::string srole(role);
std::string sprev_role(prev_role == nullptr ? "" : prev_role); std::string sprev_role(prev_role == nullptr ? "" : prev_role);
// str_toupper converts a string to all upper case, example: "abc" ==> "ABC"
auto str_toupper = [](std::string & str) { auto str_toupper = [](std::string & str) {
std::string output(str); std::string output(str);
for (size_t i = 0; i < output.size(); i++) { for (size_t i = 0; i < output.size(); i++) {
@ -17177,12 +17178,14 @@ LLAMA_API int32_t llama_chat_get_prefix(
} }
return output; return output;
}; };
// str_tofirstcap transforms first letter to uppercase, example: "abc" ==> "Abc"
auto str_tofirstcap = [](std::string & str) { auto str_tofirstcap = [](std::string & str) {
std::string output(str); std::string output(str);
output[0] = toupper(output[0]); output[0] = toupper(output[0]);
return output; return output;
}; };
switch (tmpl) { // ttmpl means "typed template"
switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
return -1; return -1;
case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_CHATML:
@ -17267,7 +17270,7 @@ LLAMA_API int32_t llama_chat_get_prefix(
} }
LLAMA_API int32_t llama_chat_get_postfix( LLAMA_API int32_t llama_chat_get_postfix(
const llama_chat_template tmpl, const llama_chat_template ttmpl,
const char * role, const char * role,
const char * prev_role, const char * prev_role,
char * buf, char * buf,
@ -17275,7 +17278,7 @@ LLAMA_API int32_t llama_chat_get_postfix(
std::stringstream ss; std::stringstream ss;
std::string srole(role); std::string srole(role);
std::string sprev_role(prev_role == nullptr ? "" : prev_role); std::string sprev_role(prev_role == nullptr ? "" : prev_role);
switch (tmpl) { switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
return -1; return -1;
case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_CHATML:
@ -17345,8 +17348,8 @@ LLAMA_API int32_t llama_chat_get_postfix(
return output.size(); return output.size();
} }
LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl) { LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl) {
switch (tmpl) { switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_CHATML:
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
@ -17371,6 +17374,8 @@ LLAMA_API int32_t llama_chat_apply_template(
bool add_ass, bool add_ass,
char * buf, char * buf,
int32_t length) { int32_t length) {
// either model or tmpl must be given
GGML_ASSERT(model != nullptr || tmpl != nullptr);
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) { if (tmpl == nullptr) {
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes