diff --git a/llama.cpp b/llama.cpp index 400713a5d..d1bbf7f85 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17307,7 +17307,7 @@ LLAMA_API int32_t llama_chat_get_model_template( } } -LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl) { +LLAMA_API enum llama_chat_template llama_chat_get_typed_template(const char * tmpl) { if (tmpl == nullptr) { return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; } @@ -17361,7 +17361,7 @@ LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl) { } LLAMA_API int32_t llama_chat_get_prefix( - const llama_chat_template ttmpl, + const enum llama_chat_template ttmpl, const char * role, const char * prev_role, char * buf, @@ -17473,7 +17473,7 @@ LLAMA_API int32_t llama_chat_get_prefix( } LLAMA_API int32_t llama_chat_get_postfix( - const llama_chat_template ttmpl, + const enum llama_chat_template ttmpl, const char * role, const char * prev_role, char * buf, @@ -17554,7 +17554,7 @@ LLAMA_API int32_t llama_chat_get_postfix( return output.size(); } -LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl) { +LLAMA_API bool llama_chat_support_system_message(const enum llama_chat_template ttmpl) { switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: @@ -17596,7 +17596,7 @@ LLAMA_API int32_t llama_chat_apply_template( } // detect template type - llama_chat_template ttmpl = llama_chat_get_typed_template(curr_tmpl.c_str()); + enum llama_chat_template ttmpl = llama_chat_get_typed_template(curr_tmpl.c_str()); bool support_system_message = llama_chat_support_system_message(ttmpl); if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) { return -1; diff --git a/llama.h b/llama.h index 4382dbe58..6705ca323 100644 --- a/llama.h +++ b/llama.h @@ -895,7 +895,7 @@ extern "C" { /// Get the value of enum llama_chat_template based on given Jinja template /// @param tmpl Jinja template (a string) /// @return The correct value of enum llama_chat_template - LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl); + LLAMA_API enum llama_chat_template llama_chat_get_typed_template(const char * tmpl); /// Get the format prefix for a given message (based on role) /// @param tmpl Use enum llama_chat_template @@ -905,11 +905,11 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the output string LLAMA_API int32_t llama_chat_get_prefix( - const llama_chat_template tmpl, - const char * role, - const char * prev_role, - char * buf, - int32_t length); + const enum llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); /// Get the format postfix for a given message (based on role) /// @param tmpl Use enum llama_chat_template @@ -919,14 +919,14 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the output string LLAMA_API int32_t llama_chat_get_postfix( - const llama_chat_template tmpl, - const char * role, - const char * prev_role, - char * buf, - int32_t length); + const enum llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); /// Check if a given template support system message or not - LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl); + LLAMA_API bool llama_chat_support_system_message(const enum llama_chat_template tmpl); // // Grammar