This commit is contained in:
ngxson 2024-04-24 15:31:46 +02:00
parent 588b72d950
commit f1a93548aa
2 changed files with 13 additions and 8 deletions

View file

@ -149,7 +149,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
}
const std::string formatted_chat(buf.data(), res);
const std::string formatted_chat(buf.data());
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});

View file

@ -17106,7 +17106,7 @@ LLAMA_API int32_t llama_chat_get_model_template(
return -1;
} else {
snprintf(buf, length, "%s", model_template.c_str());
return model_template.size() + 1;
return model_template.size();
}
}
@ -17162,7 +17162,7 @@ LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl) {
}
LLAMA_API int32_t llama_chat_get_prefix(
const llama_chat_template tmpl,
const llama_chat_template ttmpl,
const char * role,
const char * prev_role,
char * buf,
@ -17170,6 +17170,7 @@ LLAMA_API int32_t llama_chat_get_prefix(
std::stringstream ss;
std::string srole(role);
std::string sprev_role(prev_role == nullptr ? "" : prev_role);
// str_toupper converts a string to all upper case, example: "abc" ==> "ABC"
auto str_toupper = [](std::string & str) {
std::string output(str);
for (size_t i = 0; i < output.size(); i++) {
@ -17177,12 +17178,14 @@ LLAMA_API int32_t llama_chat_get_prefix(
}
return output;
};
// str_tofirstcap transforms first letter to uppercase, example: "abc" ==> "Abc"
auto str_tofirstcap = [](std::string & str) {
std::string output(str);
output[0] = toupper(output[0]);
return output;
};
switch (tmpl) {
// ttmpl means "typed template"
switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
return -1;
case LLAMA_CHAT_TEMPLATE_CHATML:
@ -17267,7 +17270,7 @@ LLAMA_API int32_t llama_chat_get_prefix(
}
LLAMA_API int32_t llama_chat_get_postfix(
const llama_chat_template tmpl,
const llama_chat_template ttmpl,
const char * role,
const char * prev_role,
char * buf,
@ -17275,7 +17278,7 @@ LLAMA_API int32_t llama_chat_get_postfix(
std::stringstream ss;
std::string srole(role);
std::string sprev_role(prev_role == nullptr ? "" : prev_role);
switch (tmpl) {
switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED:
return -1;
case LLAMA_CHAT_TEMPLATE_CHATML:
@ -17345,8 +17348,8 @@ LLAMA_API int32_t llama_chat_get_postfix(
return output.size();
}
LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl) {
switch (tmpl) {
LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl) {
switch (ttmpl) {
case LLAMA_CHAT_TEMPLATE_CHATML:
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS:
case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS:
@ -17371,6 +17374,8 @@ LLAMA_API int32_t llama_chat_apply_template(
bool add_ass,
char * buf,
int32_t length) {
// either model or tmpl must be given
GGML_ASSERT(model != nullptr || tmpl != nullptr);
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) {
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes